PGMax API reference

This page lists the project’s modules.


A sub-package defining different types of factors.


A module containing the base classes for factors in a factor graph.


Defines an enumeration factor.


Defines a logical factor.


Defines a pool factor.


class pgmax.factor.Wiring(var_states_for_edges: Union[numpy.ndarray, jax.Array])

Wiring for factors.


var_states_for_edges – Array of shape (num_edge_states, 3). For each edge state ii:

  • var_states_for_edges[ii, 0] contains its global variable state index

  • var_states_for_edges[ii, 1] contains its global edge index

  • var_states_for_edges[ii, 2] contains its global factor index
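To make the row layout concrete, here is a minimal pure-Python sketch that builds such a table for a toy graph. It assumes that edges are numbered factor by factor, in the order each factor lists its variables, and that each variable's states occupy a contiguous block starting at its entry in vars_to_starts. build_var_states_for_edges is a hypothetical helper, not PGMax's own wiring code.

```python
from typing import Dict, List, Tuple

Variable = Tuple[int, int]  # (variable hash, variable num_states)


def build_var_states_for_edges(
    factors: List[List[Variable]],
    vars_to_starts: Dict[Variable, int],
) -> List[List[int]]:
    """Hypothetical construction of a (num_edge_states, 3) wiring table.

    Every (factor, variable) pair is one edge, and every state of the
    variable on that edge is one edge state (one row of the table).
    """
    rows = []
    edge_idx = 0
    for factor_idx, variables in enumerate(factors):
        for variable in variables:
            num_states = variable[1]
            start = vars_to_starts[variable]  # first global state index of the variable
            for state in range(num_states):
                # [global variable state index, global edge index, global factor index]
                rows.append([start + state, edge_idx, factor_idx])
            edge_idx += 1
    return rows


# Usage: a binary variable a and a ternary variable b (hashes are made up).
a, b = (101, 2), (102, 3)
vars_to_starts = {a: 0, b: 2}  # a's states are 0-1, b's states are 2-4
table = build_var_states_for_edges([[a, b], [b]], vars_to_starts)
# 8 rows: 2 edge states for (factor 0, a), 3 for (factor 0, b), 3 for (factor 1, b)
```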

class pgmax.factor.Factor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray)

A factor.


variables – List of variables connected by the Factor. Each variable is represented by a tuple (variable hash, variable num_states).


log_potentials – Array of log potentials.


NotImplementedError – If compile_wiring is not implemented


class pgmax.factor.EnumWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], factor_configs_edge_states: Union[numpy.ndarray, jax.Array])

Wiring for EnumFactors.


factor_configs_edge_states – Array of shape (num_factor_configs, 2). Each row factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices:

  • factor_configs_edge_states[ii, 0] contains the global EnumFactor config index

  • factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index

Both indices only take into account the EnumFactors of the FactorGraph.


Number of valid configurations for this wiring


Number of factors covered by this wiring
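The sketch below shows one way such (config, edge_state) pairs could be produced for a single EnumFactor. It assumes, beyond what this page states, that each valid configuration contributes one row per variable of the factor, pointing at the edge state that the configuration selects on that variable's edge. config_offset stands in for the configurations of EnumFactors compiled earlier; the helper and its arguments are hypothetical.

```python
from typing import List


def build_factor_configs_edge_states(
    factor_configs: List[List[int]],
    edge_starts: List[int],
    config_offset: int = 0,
) -> List[List[int]]:
    """Hypothetical pairing of EnumFactor configurations with edge states.

    factor_configs[k][v] is the state of the factor's v-th variable in its
    k-th valid configuration; edge_starts[v] is the global index of state 0
    on the edge to that variable.
    """
    rows = []
    for config_idx, config in enumerate(factor_configs):
        for edge_start, state in zip(edge_starts, config):
            # [global EnumFactor config index, global edge_state index]
            rows.append([config_offset + config_idx, edge_start + state])
    return rows


# Usage: a factor over a binary and a ternary variable with two valid configurations.
rows = build_factor_configs_edge_states([[0, 0], [1, 2]], edge_starts=[0, 2])
# rows == [[0, 0], [0, 2], [1, 1], [1, 4]]
```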

class pgmax.factor.EnumFactor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray, factor_configs: numpy.ndarray)

An enumeration factor.


factor_configs – Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations.


log_potentials – Array of shape (num_val_configs,) containing the log of the potential value for each valid configuration.


ValueError – If:

  1. The dtype of the factor_configs array is not int

  2. The dtype of the potential array is not float

  3. factor_configs does not have the correct shape

  4. The potential array does not have the correct shape

  5. The factor_configs array contains invalid values
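As an illustration of these five conditions, the sketch below re-expresses them as checks on plain Python lists, with "int"/"float" dtypes approximated by Python types and "invalid values" read as a state index outside a variable's range (an assumption). check_enum_factor is hypothetical; PGMax performs its own validation on NumPy arrays.

```python
from typing import List, Tuple


def check_enum_factor(
    variables: List[Tuple[int, int]],
    log_potentials: List[float],
    factor_configs: List[List[int]],
) -> None:
    """Hypothetical checks mirroring the ValueError conditions listed above."""
    num_states = [n for _, n in variables]
    # (1) configurations must contain integer states
    if not all(isinstance(s, int) for config in factor_configs for s in config):
        raise ValueError("factor_configs must contain integers")
    # (2) potentials must be floats
    if not all(isinstance(p, float) for p in log_potentials):
        raise ValueError("log_potentials must contain floats")
    # (3) each configuration assigns one state to every variable
    if any(len(config) != len(variables) for config in factor_configs):
        raise ValueError("factor_configs must have shape (num_val_configs, num_variables)")
    # (4) one log potential per valid configuration
    if len(log_potentials) != len(factor_configs):
        raise ValueError("log_potentials must have shape (num_val_configs,)")
    # (5) every state index must be valid for its variable
    for config in factor_configs:
        for state, n in zip(config, num_states):
            if not 0 <= state < n:
                raise ValueError(f"invalid state {state} for a variable with {n} states")


# Usage: two binary variables whose only valid configurations are "equal states".
variables = [(1, 2), (2, 2)]
check_enum_factor(variables, [0.0, 0.0], [[0, 0], [1, 1]])  # passes silently
```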


class pgmax.factor.LogicalWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], parents_edge_states: Union[numpy.ndarray, jax.Array], children_edge_states: Union[numpy.ndarray, jax.Array], edge_states_offset: int)

Wiring for LogicalFactors.


parents_edge_states – Array of shape (num_parents, 2). For each parent ii:

  • parents_edge_states[ii, 0] contains the global LogicalFactor index

  • parents_edge_states[ii, 1] contains the message index of the parent variable’s relevant state

The message index of the parent variable’s other state is parents_edge_states[ii, 1] + edge_states_offset. Both indices only take into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.


children_edge_states – Array of shape (num_factors,). children_edge_states[ii] contains the message index of the child variable’s relevant state. The message index of the child variable’s other state is children_edge_states[ii] + edge_states_offset. Only takes into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.


ValueError – If:

  1. There are not exactly num_logical_factors distinct factor indices

  2. There is a factor index higher than num_logical_factors - 1

  3. The edge_states_offset is not 1 or -1
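The edge_states_offset arithmetic can be sketched as follows, assuming a binary variable's two edge states are stored consecutively (state 0, then state 1) starting at some index edge_start. Under that assumption, an offset of 1 implies the relevant state is state 0 and an offset of -1 implies it is state 1; this is an inference from the offsets above, and the helper is hypothetical.

```python
from typing import Tuple


def logical_edge_state_indices(edge_start: int, edge_states_offset: int) -> Tuple[int, int]:
    """Return (relevant, other) message indices on one edge to a binary variable.

    Assumes state 0 is stored at edge_start and state 1 at edge_start + 1.
    """
    if edge_states_offset not in (1, -1):
        raise ValueError("edge_states_offset must be 1 or -1")
    # +1 means the other state sits just after the relevant one, so the
    # relevant state must be state 0; -1 means the relevant state is state 1.
    relevant_state = 0 if edge_states_offset == 1 else 1
    relevant = edge_start + relevant_state
    other = relevant + edge_states_offset
    return relevant, other


print(logical_edge_state_indices(10, 1))   # ORFactor-style offset: (10, 11)
print(logical_edge_state_indices(10, -1))  # ANDFactor-style offset: (11, 10)
```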

class pgmax.factor.LogicalFactor(variables: List[Tuple[int, int]])

A logical OR/AND factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.


ValueError – If:

  1. There are fewer than 2 variables

  2. The variables are not all binary

class pgmax.factor.ORFactor(variables: List[Tuple[int, int]])

An OR factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.

An OR factor is defined as:

$$F(p_1, \ldots, p_n, c) = \begin{cases} 0 & \text{if } c = \mathrm{OR}(p_1, \ldots, p_n) \\ -\infty & \text{otherwise} \end{cases}$$


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1.
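Enumerating the definition by brute force makes it concrete: every parent assignment admits exactly one valid child value, so an OR factor with n parents has 2^n valid configurations out of 2^(n+1). The functions below are hypothetical illustrations, not part of PGMax.

```python
import itertools
import math
from typing import List, Sequence, Tuple


def or_log_potential(parents: Sequence[int], child: int) -> float:
    """Log potential of an OR factor: 0 if child == OR(parents), else -inf."""
    return 0.0 if child == int(any(parents)) else -math.inf


def or_valid_configs(num_parents: int) -> List[Tuple[int, ...]]:
    """Enumerate the (p1, ..., pn, c) configurations with a finite log potential."""
    return [
        config
        for config in itertools.product((0, 1), repeat=num_parents + 1)
        if or_log_potential(config[:-1], config[-1]) == 0.0
    ]


print(or_valid_configs(2))
# [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]
```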

class pgmax.factor.ANDFactor(variables: List[Tuple[int, int]])

An AND factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.

An AND factor is defined as:

$$F(p_1, \ldots, p_n, c) = \begin{cases} 0 & \text{if } c = \mathrm{AND}(p_1, \ldots, p_n) \\ -\infty & \text{otherwise} \end{cases}$$


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ANDFactors the edge_states_offset is -1.
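By De Morgan's law, c = AND(p1, …, pn) holds exactly when 1 - c = OR(1 - p1, …, 1 - pn), so an AND factor is an OR factor with every variable's two states swapped. The hypothetical check below verifies this exhaustively on small cases; it illustrates the definitions only and is not PGMax code.

```python
import itertools
import math
from typing import Sequence


def or_log_potential(parents: Sequence[int], child: int) -> float:
    """Log potential of an OR factor: 0 if child == OR(parents), else -inf."""
    return 0.0 if child == int(any(parents)) else -math.inf


def and_log_potential(parents: Sequence[int], child: int) -> float:
    """Log potential of an AND factor: 0 if child == AND(parents), else -inf."""
    return 0.0 if child == int(all(parents)) else -math.inf


def de_morgan_holds(num_parents: int) -> bool:
    """Check F_AND(p, c) == F_OR(1 - p, 1 - c) on every binary configuration."""
    for config in itertools.product((0, 1), repeat=num_parents + 1):
        parents, child = config[:-1], config[-1]
        flipped = [1 - p for p in parents]
        if and_log_potential(parents, child) != or_log_potential(flipped, 1 - child):
            return False
    return True


print(de_morgan_holds(3))  # True
```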


class pgmax.factor.PoolWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], pool_choices_edge_states: Union[numpy.ndarray, jax.Array], pool_indicators_edge_states: Union[numpy.ndarray, jax.Array])

Wiring for PoolFactors.


pool_choices_edge_states – Array of shape (num_pool_choices, 2). For each pool choice ii:

  • pool_choices_edge_states[ii, 0] contains the global PoolFactor index

  • pool_choices_edge_states[ii, 1] contains the message index of the pool choice variable’s state 0

The message index of the pool choice variable’s state 1 is pool_choices_edge_states[ii, 1] + 1. Both indices only take into account the PoolFactors of the FactorGraph.


pool_indicators_edge_states – Array of shape (num_pool_factors,). pool_indicators_edge_states[ii] contains the message index of the pool indicator variable’s state 0. The message index of the pool indicator variable’s state 1 is pool_indicators_edge_states[ii] + 1. Only takes into account the PoolFactors of the FactorGraph.


ValueError – If:

  1. There are not exactly num_pool_factors distinct factor indices

  2. There is a factor index higher than num_pool_factors - 1

class pgmax.factor.PoolFactor(variables: List[Tuple[int, int]])

A Pool factor of the form (pc1, …, pcn, pi), where (pc1, …, pcn) are the pool choices and pi is the pool indicator.

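A Pool factor is defined as:

$$F(pc_1, \ldots, pc_n, pi) = \begin{cases} 0 & \text{if } (pc_1 = \cdots = pc_n = pi = 0) \text{ or } \big(pi = 1 \text{ and } \sum_{j=1}^{n} pc_j = 1\big) \\ -\infty & \text{otherwise} \end{cases}$$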

i.e., either (a) all the variables are set to 0, or (b) the pool indicator variable is set to 1 and exactly one of the pool choice variables is set to 1.

Note: placing the pool indicator at the end allows us to reuse our existing infrastructure for wiring logical factors.
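A brute-force enumeration of this definition shows that a Pool factor with n pool choices has exactly n + 1 valid configurations: the all-zero one, plus one per choice that can be switched on together with the indicator. The helpers below are hypothetical illustrations, not part of PGMax.

```python
import itertools
import math
from typing import List, Sequence, Tuple


def pool_log_potential(choices: Sequence[int], indicator: int) -> float:
    """0 if all variables are 0, or if the indicator is 1 and exactly one choice is 1; else -inf."""
    all_off = indicator == 0 and sum(choices) == 0
    one_on = indicator == 1 and sum(choices) == 1
    return 0.0 if all_off or one_on else -math.inf


def pool_valid_configs(num_choices: int) -> List[Tuple[int, ...]]:
    """Enumerate the (pc1, ..., pcn, pi) configurations with a finite log potential."""
    return [
        config
        for config in itertools.product((0, 1), repeat=num_choices + 1)
        if pool_log_potential(config[:-1], config[-1]) == 0.0
    ]


print(pool_valid_configs(3))
# [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 0, 0, 1)]
```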


A sub-package containing functions to represent a factor graph.


A module containing the core class to build a factor graph.


class pgmax.fgraph.FactorGraphState(variable_groups: Sequence[pgmax.vgroup.vgroup.VarGroup], vars_to_starts: Mapping[Tuple[int, int], int], num_var_states: int, total_factor_num_states: int, factor_type_to_msgs_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_type_to_potentials_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_group_to_potentials_starts: OrderedDict[pgmax.fgroup.fgroup.FactorGroup, int], log_potentials: numpy.ndarray, evidence_to_vars: numpy.ndarray, wiring: OrderedDict[Type[pgmax.factor.factor.Factor], pgmax.factor.factor.Wiring])



variable_groups – VarGroups in the FactorGraph.


vars_to_starts – Maps variables to their starting indices in the flat evidence array: flat_evidence[vars_to_starts[variable] : vars_to_starts[variable] + num_states] contains the evidence for the variable, where num_states = variable[1] is its number of states.


num_var_states – Total number of variable states.


total_factor_num_states – Size of the flat ftov messages array.


factor_type_to_msgs_range – Maps factor types to their start and end indices in the flat ftov messages.


factor_type_to_potentials_range – Maps factor types to their start and end indices in the flat log potentials.


factor_group_to_potentials_starts – Maps factor groups to their starting indices in the flat log potentials.


log_potentials – Flat log potentials array concatenated for each factor type.


evidence_to_vars – Maps the evidence entries to their variable indices.


wiring – Wiring derived for each factor type.
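The vars_to_starts convention can be illustrated with plain Python. The sketch assumes variables are laid out contiguously in the order given, so the last running offset equals num_var_states; compute_vars_to_starts and the variable hashes are hypothetical.

```python
from typing import Dict, List, Tuple

Variable = Tuple[int, int]  # (variable hash, variable num_states)


def compute_vars_to_starts(variables: List[Variable]) -> Tuple[Dict[Variable, int], int]:
    """Lay variables out contiguously; return their start indices and the total number of states."""
    vars_to_starts = {}
    start = 0
    for variable in variables:
        vars_to_starts[variable] = start
        start += variable[1]  # advance by the variable's number of states
    return vars_to_starts, start


a, b = (101, 2), (102, 3)  # made-up hashes
vars_to_starts, num_var_states = compute_vars_to_starts([a, b])
flat_evidence = [0.0, 1.0, 0.5, -0.5, 2.0]
start = vars_to_starts[b]
b_evidence = flat_evidence[start : start + b[1]]  # [0.5, -0.5, 2.0]
```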

class pgmax.fgraph.FactorGraph(variable_groups: Union[pgmax.vgroup.vgroup.VarGroup, Sequence[pgmax.vgroup.vgroup.VarGroup]])

Class for representing a factor graph.

Factors in a graph are clustered in factor groups, which are grouped according to their factor types.


variable_groups – A single VarGroup or a list of VarGroups.


A sub-package defining different types of groups of factors.


Defines EnumFactorGroup and PairwiseFactorGroup.


A module containing the base classes for factor groups in a factor graph.


Defines LogicalFactorGroup and its two children, ORFactorGroup and ANDFactorGroup.


Defines PoolFactorGroup.


class pgmax.fgroup.FactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of Factors.


variables_for_factors – A list of lists of variables. Each inner list contains the variables connected to a Factor. The same variable can be connected to multiple Factors.


Optional array containing an explicit enumeration of all valid configurations


Array of log potentials.


Factor type shared by all the Factors in the FactorGroup.


ValueError – if the FactorGroup does not contain a Factor

class pgmax.fgroup.SingleFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], single_factor: pgmax.factor.factor.Factor)

Class to represent a FactorGroup with a single factor.

For internal use only. Should not be directly used to add FactorGroups to a factor graph.


single_factor – The single factor in the SingleFactorGroup.


class pgmax.fgroup.EnumFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], factor_configs: numpy.ndarray, log_potentials: Optional[Union[numpy.ndarray, jax.Array]] = None)

Class to represent a group of EnumFactors.

All factors in the group are assumed to have the same set of valid configurations. The associated log potentials can however be different across factors.


factor_configs – Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations.


log_potentials – Optional 1D array of shape (num_val_configs,) or 2D array of shape (num_factors, num_val_configs). If 1D, the log potentials are copied for each factor of the group. If 2D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.


Factor type shared by all the Factors in the FactorGroup.


ValueError – If:

  1. The specified log_potentials is not of the expected shape

  2. The dtype of the potential array is not float
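The three accepted forms of log_potentials (None, 1D, 2D) can be sketched with plain Python lists. expand_log_potentials is a hypothetical helper that reproduces the shape rules described above; PGMax itself operates on NumPy/JAX arrays.

```python
from typing import List, Optional, Sequence, Union

LogPotentials = Union[Sequence[float], Sequence[Sequence[float]]]


def expand_log_potentials(
    log_potentials: Optional[LogPotentials],
    num_factors: int,
    num_val_configs: int,
) -> List[List[float]]:
    """Return a (num_factors, num_val_configs) table of log potentials."""
    if log_potentials is None:
        # Uniform potentials: every valid configuration gets log potential 0.
        return [[0.0] * num_val_configs for _ in range(num_factors)]
    if len(log_potentials) > 0 and isinstance(log_potentials[0], (list, tuple)):
        # 2D input: one row per factor.
        if len(log_potentials) != num_factors or any(
            len(row) != num_val_configs for row in log_potentials
        ):
            raise ValueError("expected shape (num_factors, num_val_configs)")
        return [list(row) for row in log_potentials]
    # 1D input: the same potentials are copied for every factor.
    if len(log_potentials) != num_val_configs:
        raise ValueError("expected shape (num_val_configs,)")
    return [list(log_potentials) for _ in range(num_factors)]


print(expand_log_potentials([1.0, -1.0], num_factors=3, num_val_configs=2))
# [[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]]
```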

class pgmax.fgroup.PairwiseFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], log_potential_matrix: Optional[Union[numpy.ndarray, jax.Array]] = None)

Class to represent a group of EnumFactors where each factor connects to two different variables.

All factors in the group are assumed to be such that all possible configurations of the two variables are valid. The associated log potentials can however be different across factors.


log_potential_matrix – Optional 2D array of shape (num_states1, num_states2) or 3D array of shape (num_factors, num_states1, num_states2), where num_states1 and num_states2 are the numbers of states of the first and second variables involved in each factor. If 2D, the log potentials are copied for each factor of the group. If 3D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.


Factor type shared by all the Factors in the FactorGroup.


ValueError – If:

  1. The specified log_potential_matrix is not a 2D or 3D array

  2. The dtype of the potential array is not float

  3. Some pairwise factors connect to fewer or more than 2 variables

  4. The specified log_potential_matrix does not match the number of factors

  5. The specified log_potential_matrix does not match the number of variable states of the variables in the factors
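Because a PairwiseFactorGroup is a group of EnumFactors in which every configuration of the two variables is valid, a 2D log_potential_matrix can be read as an explicit enumeration of all (state1, state2) pairs. The sketch below flattens it that way, assuming row-major order; the ordering PGMax uses internally is not stated on this page, and pairwise_to_enum is hypothetical.

```python
import itertools
from typing import List, Sequence, Tuple


def pairwise_to_enum(
    log_potential_matrix: Sequence[Sequence[float]],
) -> Tuple[List[List[int]], List[float]]:
    """Flatten a (num_states1, num_states2) matrix into EnumFactor-style inputs."""
    num_states1 = len(log_potential_matrix)
    num_states2 = len(log_potential_matrix[0])
    # Every (state1, state2) pair is a valid configuration, in row-major order.
    factor_configs = [
        [s1, s2] for s1, s2 in itertools.product(range(num_states1), range(num_states2))
    ]
    log_potentials = [log_potential_matrix[s1][s2] for s1, s2 in factor_configs]
    return factor_configs, log_potentials


configs, pots = pairwise_to_enum([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
# configs: [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
# pots:    [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```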


class pgmax.fgroup.LogicalFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of LogicalFactors.

All factors in the group are assumed to have the same edge_states_offset. Consequently, the factors are either all ORFactors or all ANDFactors.


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.

class pgmax.fgroup.ORFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of ORFactors.


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1.

class pgmax.fgroup.ANDFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of ANDFactors.


edge_states_offset – Offset to go from a variable’s relevant state to its other state. For ANDFactors the edge_states_offset is -1.


class pgmax.fgroup.PoolFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of PoolFactors.


A sub-package containing functions to perform belief propagation.


A module containing the core message-passing functions for belief propagation.


A module defining container classes for belief propagation states.


A module solving the smoothed dual of the LP relaxation of the MAP problem.


Compute the energy of a MAP decoding.


Shared context classes for the inference methods.


class pgmax.infer.BeliefPropagation(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays], run_bp: Callable[[], pgmax.infer.bp_state.BPArrays], run_with_diffs: Callable[[], Tuple[pgmax.infer.bp_state.BPArrays, jax.Array]])

Belief propagation functions.


run_bp – Backward-compatible version of run.


run_with_diffs – Runs inference while monitoring convergence.

pgmax.infer.BP(bp_state: pgmax.infer.bp_state.BPState, temperature: Optional[float] = 0.0) → pgmax.infer.bp.BeliefPropagation

Returns the generated belief propagation functions.

  • bp_state – Belief propagation state.

  • temperature – Temperature for loopy belief propagation. 1.0 corresponds to sum-product and 0.0 to max-product. Used for backward compatibility.
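In log space, sum-product combines values with log-sum-exp and max-product with max. The temperature-scaled log-sum-exp T·log Σ exp(v/T) interpolates between the two; it is the log-space form of the soft max-marginal formula (Σ p^(1/Temp))^Temp given for get_marginals. The sketch is an illustration under that reading, not PGMax's message-update code.

```python
import math
from typing import Sequence


def soft_max(values: Sequence[float], temperature: float) -> float:
    """Temperature-scaled log-sum-exp: T * log(sum(exp(v / T))), with T == 0 meaning max."""
    m = max(values)
    if temperature == 0.0:
        return m  # max-product
    # Subtract the max before exponentiating for numerical stability.
    return m + temperature * math.log(
        sum(math.exp((v - m) / temperature) for v in values)
    )


vals = [1.0, 3.0]
print(soft_max(vals, 1.0))  # log(e^1 + e^3) ≈ 3.1269 (sum-product)
print(soft_max(vals, 0.0))  # 3.0 (max-product)
```

As T shrinks toward 0, the result stays between max(values) and max(values) + T·log(len(values)), so it approaches the max-product value.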

pgmax.infer.get_marginals(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]

Returns the normalized beliefs of several VarGroups, so that they form a valid probability distribution.

When the temperature is equal to 1.0, get_marginals returns the sum-product estimate of the marginal probabilities.

When the temperature is equal to 0.0, get_marginals returns the max-product estimate of the normalized max-marginal probabilities, defined as:

$$\mathrm{norm\_max\_marginals}(x_i^*) \propto \max_{x:\, x_i = x_i^*} p(x)$$

When the temperature Temp is strictly between 0.0 and 1.0, get_marginals returns the belief propagation estimate of the normalized soft max-marginal probabilities, defined as:

$$\mathrm{norm\_soft\_max\_marginals}(x_i^*) \propto \Big(\sum_{x:\, x_i = x_i^*} p(x)^{1/\mathrm{Temp}}\Big)^{\mathrm{Temp}}$$


beliefs – A dictionary containing the beliefs of the VarGroups.
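On a model small enough to enumerate, all three cases can be computed exactly from the formulas above. The sketch below does so for one binary variable of a hypothetical two-variable joint distribution. It brute-forces the definitions and is not how PGMax computes beliefs, which come from belief propagation and are therefore estimates on loopy graphs.

```python
import itertools
import math
from typing import Callable, List, Tuple


def normalized_soft_max_marginal(
    log_p: Callable[[Tuple[int, ...]], float],
    num_vars: int,
    var_idx: int,
    temperature: float,
) -> List[float]:
    """Brute-force normalized soft max-marginal of one binary variable."""
    scores = []
    for state in (0, 1):
        vals = [
            log_p(x)
            for x in itertools.product((0, 1), repeat=num_vars)
            if x[var_idx] == state
        ]
        m = max(vals)
        if temperature == 0.0:
            scores.append(m)  # log of max_{x: x_i = state} p(x)
        else:
            # log of (sum_{x: x_i = state} p(x)^(1/T))^T
            scores.append(
                m + temperature * math.log(sum(math.exp((v - m) / temperature) for v in vals))
            )
    # Normalize across the variable's states (softmax of the log scores).
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


joint = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.25, (1, 1): 0.25}


def log_p(x: Tuple[int, ...]) -> float:
    return math.log(joint[x])


print(normalized_soft_max_marginal(log_p, 2, 0, 1.0))  # ≈ [0.5, 0.5]: true marginals
print(normalized_soft_max_marginal(log_p, 2, 0, 0.0))  # ≈ [0.615, 0.385]: 0.4 and 0.25, normalized
```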


class pgmax.infer.BPArrays(log_potentials: Union[numpy.ndarray, jax.Array], ftov_msgs: Union[numpy.ndarray, jax.Array], evidence: Union[numpy.ndarray, jax.Array])

Container for the relevant flat arrays used in belief propagation.


log_potentials – Flat log potentials array.


ftov_msgs – Flat factor to variable messages array.


evidence – Flat evidence array.

class pgmax.infer.LogPotentials(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)

Class for storing and manipulating log potentials.


fg_state – Factor graph state.


value – Optionally specify an initial value.

Raises: ValueError – If the provided value’s shape does not match the expected log_potentials shape.

class pgmax.infer.FToVMessages(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)

Class for storing and manipulating factor to variable messages.


fg_state – Factor graph state.


value – Optionally specify an initial value for the ftov messages.

Raises: ValueError – If the provided value does not match the expected ftov messages shape.


class pgmax.infer.Evidence(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)

Class for storing and manipulating evidence.


fg_state – Factor graph state.


value – Optionally specify an initial value for the evidence.

Raises: ValueError – If the provided value does not match the expected evidence shape.

class pgmax.infer.BPState(log_potentials: pgmax.infer.bp_state.LogPotentials, ftov_msgs: pgmax.infer.bp_state.FToVMessages, evidence: pgmax.infer.bp_state.Evidence)

Container class for belief propagation states, including log potentials, ftov messages and evidence (unary log potentials).


log_potentials – Log potentials of the model.


ftov_msgs – Factor to variable messages.


evidence – Evidence (unary log potentials) for variables.


Associated factor graph state.

Raises: ValueError – If log_potentials, ftov_msgs or evidence are not derived from the same factor graph state.


class pgmax.infer.SmoothDualLP(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays], run_with_objvals: Callable[[], Tuple[pgmax.infer.bp_state.BPArrays, float]], decode_primal_unaries: Callable[[], Tuple[Dict[Hashable, Any], Dict[Hashable, Any]]], get_primal_upper_bound: Callable[[], float], get_map_lower_bound: Callable[[], float], get_bp_updates: Callable[[], Tuple[jax.Array, jax.Array, jax.Array]])

Smooth Dual LP-MAP solver functions.


run_with_objvals – Solves the Smooth Dual LP-MAP problem via accelerated gradient descent (or subgradient descent) and returns the objective value at each step.


decode_primal_unaries – Decodes the primal LP-MAP unaries and returns a state assignment for each variable of the FactorGraph.


get_primal_upper_bound – Returns an upper bound of the optimal objective value of the (non-smooth) LP-MAP problem.


get_map_lower_bound – Returns a lower bound of the optimal objective value of the (Integer Programming) MAP problem.


get_bp_updates – Used for unit tests. Gets the BP updates involved in the SDLP solver.
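As a generic illustration of the accelerated gradient scheme mentioned for run_with_objvals, the sketch below applies Nesterov's method, with the common (k - 1)/(k + 2) momentum schedule, to a toy smooth convex quadratic and records the objective after every step. It does not model the Smooth Dual LP objective, whose details are not documented on this page; all names are hypothetical.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def accelerated_gradient_descent(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x0: Vector,
    step_size: float,
    num_iters: int,
) -> Tuple[Vector, List[float]]:
    """Nesterov-accelerated gradient descent; returns the last iterate and per-step objectives."""
    x_prev = list(x0)
    y = list(x0)
    objvals = []
    for k in range(1, num_iters + 1):
        g = grad(y)
        x = [yi - step_size * gi for yi, gi in zip(y, g)]  # gradient step from the extrapolated point
        momentum = (k - 1) / (k + 2)
        y = [xi + momentum * (xi - xpi) for xi, xpi in zip(x, x_prev)]  # extrapolation
        x_prev = x
        objvals.append(f(x))
    return x_prev, objvals


# Usage: a smooth convex quadratic with minimum 0 at (3, -1); step size 1/L with L = 4.
def f(x: Vector) -> float:
    return (x[0] - 3.0) ** 2 + 2.0 * (x[1] + 1.0) ** 2


def grad_f(x: Vector) -> Vector:
    return [2.0 * (x[0] - 3.0), 4.0 * (x[1] + 1.0)]


x_final, objvals = accelerated_gradient_descent(f, grad_f, [0.0, 0.0], step_size=0.25, num_iters=200)
```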

pgmax.infer.SDLP(bp_state: pgmax.infer.bp_state.BPState) → pgmax.infer.dual_lp.SmoothDualLP

Returns the generated Smooth Dual LP-MAP functions.


bp_state – Belief propagation state.


pgmax.infer.compute_energy(bp_state: pgmax.infer.bp_state.BPState, bp_arrays: pgmax.infer.bp_state.BPArrays, map_states: Dict[Hashable, Any], debug_mode=False) → Tuple[float, Any, Any]

Returns the energy of a decoding, expressed by its MAP states.

  • bp_state – Belief propagation state

  • bp_arrays – Arrays of log_potentials, ftov_msgs, evidence

  • map_states – A dictionary mapping the VarGroups of the FactorGraph to their MAP states

  • debug_mode – If True, also returns the individual energies of each variable and each factor in the FactorGraph


Returns:

  • The energy of the decoding

  • vars_energies – The energy of each individual variable (only in debug mode)

  • factors_energies – The energy of each individual factor (only in debug mode)

Return type: Tuple[float, Any, Any]


Note: Remember that the lower the energy, the better the decoding!
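A minimal sketch of an energy computation, assuming (this page does not spell it out) that the energy is the negated sum of the evidence and factor log potentials evaluated at the decoding. Under that assumption a more probable decoding has a lower energy, and a decoding that violates a factor has infinite energy. Variables are identified by hypothetical string names and factors are given as explicit configuration tables.

```python
import math
from typing import Dict, Hashable, List, Tuple

Factor = Tuple[Tuple[Hashable, ...], Dict[Tuple[int, ...], float]]


def energy_of_decoding(
    evidence: Dict[Hashable, List[float]],
    factors: List[Factor],
    map_states: Dict[Hashable, int],
) -> float:
    """Negated sum of the unary and factor log potentials selected by map_states."""
    energy = 0.0
    for var, unaries in evidence.items():
        energy -= unaries[map_states[var]]
    for variables, log_potentials in factors:
        config = tuple(map_states[v] for v in variables)
        # Configurations missing from the table are invalid: log potential -inf.
        energy -= log_potentials.get(config, -math.inf)
    return energy


evidence = {"a": [0.0, 1.0], "b": [0.5, 0.0]}
factors = [(("a", "b"), {(0, 0): 1.0, (1, 1): 1.0})]  # only "equal states" is valid
print(energy_of_decoding(evidence, factors, {"a": 1, "b": 1}))  # -2.0
print(energy_of_decoding(evidence, factors, {"a": 0, "b": 0}))  # -1.5
print(energy_of_decoding(evidence, factors, {"a": 1, "b": 0}))  # inf
```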


pgmax.infer.decode_map_states(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]

Returns the MAP states of several VarGroups given their beliefs.


beliefs – A dictionary containing the beliefs of the VarGroups.
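Decoding MAP states from beliefs amounts to picking, for every variable, the state with the highest belief. The sketch below is a hypothetical plain-Python stand-in for decode_map_states; it assumes each VarGroup's beliefs are given as a list of per-variable belief vectors, whereas PGMax works on NumPy/JAX arrays.

```python
from typing import Dict, Hashable, List, Sequence


def argmax_states(beliefs: Dict[Hashable, Sequence[Sequence[float]]]) -> Dict[Hashable, List[int]]:
    """For each variable of each group, pick the state with the highest belief (first one on ties)."""
    return {
        group: [max(range(len(b)), key=b.__getitem__) for b in group_beliefs]
        for group, group_beliefs in beliefs.items()
    }


beliefs = {"grid": [[0.1, 2.0], [3.0, -1.0]], "extra": [[0.0, 0.5, 0.2]]}
print(argmax_states(beliefs))  # {'grid': [1, 0], 'extra': [1]}
```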

class pgmax.infer.Inferer(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays])

Inferer pure functions.


init – Function to create log_potentials, ftov_msgs and evidence.


update – Function to update log_potentials, ftov_msgs and evidence.


to_bp_state – Function to reconstruct the BPState from a BPArrays.


get_beliefs – Function to calculate beliefs from a BPArrays.


run – Function to run inference.

class pgmax.infer.InfererContext(bp_state: pgmax.infer.bp_state.BPState)

Shared inference context for the different inferers.


bp_state – Belief propagation state.


A sub-package defining different types of groups of variables.


A module containing the base class for variable groups in a Factor Graph.


A module containing a subclass of VarGroup for n-dimensional grids of variables.


A module containing a variable dictionary class inheriting from the base VarGroup.


class pgmax.vgroup.VarGroup(num_states: Union[int, numpy.ndarray])

Class to represent a group of variables.

Each variable is represented via a tuple of the form (variable hash, variable num_states).


num_states – An integer or an array specifying the number of states of the variables in this VarGroup.


Hash value


class pgmax.vgroup.NDVarArray(num_states: Union[int, numpy.ndarray], shape: Tuple[int, ...])

Subclass of VarGroup for n-dimensional grids of variables.


num_states – An integer or an array specifying the number of states of the variables in this VarGroup.


shape – Tuple specifying the size of each dimension of the grid (similar to the notion of a NumPy ndarray shape).
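Each variable of a VarGroup is a (variable hash, variable num_states) tuple. The sketch below enumerates the variables of a hypothetical grid in that form, with num_states given either as a single integer or as one value per grid position (flattened in row-major order here, whereas PGMax accepts an array). How PGMax actually computes variable hashes is not documented on this page, so the hash used here is a placeholder.

```python
import itertools
from typing import Dict, Sequence, Tuple, Union


def grid_variables(
    num_states: Union[int, Sequence[int]],
    shape: Tuple[int, ...],
    group_id: int = 0,
) -> Dict[Tuple[int, ...], Tuple[int, int]]:
    """Map each grid position to a (variable hash, num_states) tuple."""
    positions = list(itertools.product(*(range(n) for n in shape)))
    if isinstance(num_states, int):
        per_variable = [num_states] * len(positions)  # same number of states everywhere
    else:
        per_variable = list(num_states)  # one entry per position, in row-major order
        if len(per_variable) != len(positions):
            raise ValueError("num_states must have one entry per grid position")
    # Placeholder hash scheme, made up for illustration.
    return {
        pos: (hash((group_id, pos)), n) for pos, n in zip(positions, per_variable)
    }


binary_grid = grid_variables(2, (2, 3))
mixed_grid = grid_variables([2, 3, 4, 2], (2, 2))
```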


class pgmax.vgroup.VarDict(num_states: Union[int, numpy.ndarray], variable_names: Tuple[Any, ...])

A variable dictionary that contains a set of variables.


num_states – An integer or an array specifying the number of states of the variables in this VarGroup.


variable_names – A tuple of all the names of the variables in this VarGroup.


A module containing helper functions.