PGMax Reference Documentation
PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.
General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
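To make the role of vmap and grad concrete, here is a minimal sketch of sum-product LBP written directly in JAX for a toy three-variable cycle with binary states. It does not use the PGMax API: the names run_lbp, EDGES, SRC, DST, and REV, the graph, and the potentials are all hypothetical and chosen only for illustration. The point is that a pure function built from JAX primitives can be batched and differentiated without modification.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Toy loopy graph (an assumption for illustration): a 3-cycle of binary variables.
NUM_VARS, NUM_STATES = 3, 2
EDGES = jnp.array([[0, 1], [1, 2], [2, 0]])
# Each undirected edge yields two directed edges: d and its reverse REV[d].
SRC = jnp.concatenate([EDGES[:, 0], EDGES[:, 1]])
DST = jnp.concatenate([EDGES[:, 1], EDGES[:, 0]])
REV = jnp.concatenate([jnp.arange(3, 6), jnp.arange(0, 3)])


def run_lbp(unary, pairwise, num_iters=20):
    """Hypothetical helper: sum-product LBP in log space.

    unary:    (3, 2) log potentials, one row per variable.
    pairwise: (3, 2, 2) log potentials, indexed
              [edge, state of EDGES[e, 0], state of EDGES[e, 1]].
    Returns approximate marginals of shape (3, 2).
    """
    # Per directed edge, a table indexed [source state, destination state].
    pw = jnp.concatenate([pairwise, jnp.swapaxes(pairwise, 1, 2)])

    def incoming_sum(msgs):
        # Sum of log messages arriving at each variable.
        zeros = jnp.zeros((NUM_VARS, NUM_STATES), dtype=msgs.dtype)
        return zeros.at[DST].add(msgs)

    def step(msgs, _):
        # Evidence at the source, excluding the message sent back by the destination.
        pre = unary[SRC] + incoming_sum(msgs)[SRC] - msgs[REV]
        new = logsumexp(pre[:, :, None] + pw, axis=1)
        # Normalize each message for numerical stability.
        new = new - logsumexp(new, axis=1, keepdims=True)
        return new, None

    msgs0 = jnp.zeros((SRC.shape[0], NUM_STATES), dtype=unary.dtype)
    msgs, _ = jax.lax.scan(step, msgs0, None, length=num_iters)
    return jax.nn.softmax(unary + incoming_sum(msgs), axis=-1)


key = jax.random.PRNGKey(0)
unary = jax.random.normal(key, (NUM_VARS, NUM_STATES))
pairwise = jnp.tile(jnp.array([[0.5, -0.5], [-0.5, 0.5]]), (3, 1, 1))

# Single model.
beliefs = run_lbp(unary, pairwise)

# vmap: run LBP on a batch of unary potentials with shared pairwise potentials.
batch_unary = jax.random.normal(key, (4, NUM_VARS, NUM_STATES))
batch_beliefs = jax.vmap(run_lbp, in_axes=(0, None))(batch_unary, pairwise)


# grad: differentiate a loss through all LBP iterations.
def loss(unary, pairwise):
    return -jnp.log(run_lbp(unary, pairwise)[0, 1])


grads = jax.grad(loss)(unary, pairwise)
```

Because the iteration count is a fixed Python integer and the loop is expressed with jax.lax.scan, the whole computation stays traceable, which is what lets vmap and grad apply to it. PGMax's generated functions handle general factor graphs; this sketch covers only pairwise factors with a uniform number of states.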
See our blog post and companion paper for more details.
Getting Started:
Developer Documentation:
API Documentation