PGMax API reference

This page lists the project’s modules.

Factor

A sub-package defining different types of factors.

factor

A module containing the base classes for factors in a factor graph.

enum

Defines an enumeration factor.

logical

Defines a logical factor.

pool

Defines a pool factor.

factor

class pgmax.factor.Wiring(var_states_for_edges: Union[numpy.ndarray, jax.Array])

Wiring for factors.

var_states_for_edges

Array of shape (num_edge_states, 3). For each edge state:
  • var_states_for_edges[ii, 0] contains its global variable state index
  • var_states_for_edges[ii, 1] contains its global edge index
  • var_states_for_edges[ii, 2] contains its global factor index

class pgmax.factor.Factor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray)

A factor.

variables

List of variables connected by the Factor. Each variable is represented by a tuple (variable hash, variable num_states)

log_potentials

Array of log potentials

Raises:

NotImplementedError – If compile_wiring is not implemented

enum

class pgmax.factor.EnumWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], factor_configs_edge_states: Union[numpy.ndarray, jax.Array])

Wiring for EnumFactors.

factor_configs_edge_states

Array of shape (num_factor_configs, 2). factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices:
  • factor_configs_edge_states[ii, 0] contains the global EnumFactor config index
  • factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index

Both indices only take into account the EnumFactors of the FactorGraph.

num_val_configs

Number of valid configurations for this wiring

num_factors

Number of factors covered by this wiring

class pgmax.factor.EnumFactor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray, factor_configs: numpy.ndarray)

An enumeration factor.

factor_configs

Array of shape (num_val_configs, num_variables). An array containing an explicit enumeration of all valid configurations.

log_potentials

Array of shape (num_val_configs,). An array containing the log of the potential value for each valid configuration.

Raises:

ValueError – If:
  (1) The dtype of the factor_configs array is not int
  (2) The dtype of the potential array is not float
  (3) factor_configs does not have the correct shape
  (4) The potential array does not have the correct shape
  (5) The factor_configs array contains invalid values
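To illustrate how factor_configs and log_potentials pair up row by row, here is a minimal plain-Python sketch. The helper name enum_log_potential is hypothetical and not part of PGMax, and treating a configuration that is not enumerated as having log potential -inf is an assumption drawn from the "valid configurations" wording above.

```python
import math

def enum_log_potential(factor_configs, log_potentials, config):
    """Return the log potential of `config`, or -inf if it is not enumerated.

    Hypothetical helper for illustration only; not part of the PGMax API.
    """
    if len(factor_configs) != len(log_potentials):
        raise ValueError("Expected one log potential per valid configuration")
    for valid_config, log_potential in zip(factor_configs, log_potentials):
        if tuple(valid_config) == tuple(config):
            return log_potential
    # Assumption: configurations that are not listed are invalid.
    return -math.inf

# Two binary variables; only the configurations where they agree are valid.
factor_configs = [(0, 0), (1, 1)]  # shape (num_val_configs, num_variables)
log_potentials = [0.5, 1.5]        # shape (num_val_configs,)

agree = enum_log_potential(factor_configs, log_potentials, (1, 1))
disagree = enum_log_potential(factor_configs, log_potentials, (0, 1))
```

Here agree picks up the second log potential, while disagree falls through to -inf because (0, 1) is not listed.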

logical

class pgmax.factor.LogicalWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], parents_edge_states: Union[numpy.ndarray, jax.Array], children_edge_states: Union[numpy.ndarray, jax.Array], edge_states_offset: int)

Wiring for LogicalFactors.

parents_edge_states

Array of shape (num_parents, 2).
  • parents_edge_states[ii, 0] contains the global LogicalFactor index
  • parents_edge_states[ii, 1] contains the message index of the parent variable’s relevant state

The message index of the parent variable’s other state is parents_edge_states[ii, 1] + edge_states_offset. Both indices only take into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.

children_edge_states

Array of shape (num_factors,). children_edge_states[ii] contains the message index of the child variable’s relevant state. The message index of the child variable’s other state is children_edge_states[ii] + edge_states_offset. Only takes into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1. For ANDFactors the edge_states_offset is -1.

Raises:

ValueError – If:
  (1) There are not num_logical_factors different factor indices
  (2) There is a factor index higher than num_logical_factors - 1
  (3) The edge_states_offset is not 1 or -1

class pgmax.factor.LogicalFactor(variables: List[Tuple[int, int]])

A logical OR/AND factor of the form (p1,…,pn, c) where p1,…,pn are the parent variables and c is the child variable.

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1. For ANDFactors the edge_states_offset is -1.

Raises:

ValueError – If:
  (1) There are fewer than 2 variables
  (2) The variables are not all binary

class pgmax.factor.ORFactor(variables: List[Tuple[int, int]])

An OR factor of the form (p1,…,pn, c) where p1,…,pn are the parent variables and c is the child variable.

An OR factor is defined as:

F(p1, p2, …, pn, c) = 0 <=> c = OR(p1, p2, …, pn)
F(p1, p2, …, pn, c) = -inf otherwise

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1.
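The OR factor definition can be read as a function from an assignment to a log potential. The sketch below spells that out in plain Python; or_factor_log_potential is a hypothetical name used only for illustration and does not reflect how PGMax evaluates these factors internally.

```python
import math

def or_factor_log_potential(parents, child):
    """Return 0.0 if child == OR(parents), and -inf otherwise."""
    values = list(parents) + [child]
    if any(value not in (0, 1) for value in values):
        raise ValueError("All variables of a logical factor must be binary")
    # The child must be 1 exactly when at least one parent is 1.
    return 0.0 if child == int(any(parents)) else -math.inf

valid = or_factor_log_potential([0, 1, 0], 1)    # child matches OR(parents)
invalid = or_factor_log_potential([0, 0, 0], 1)  # child is on with no parent on
```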

class pgmax.factor.ANDFactor(variables: List[Tuple[int, int]])

An AND factor of the form (p1,…,pn, c) where p1,…,pn are the parent variables and c is the child variable.

An AND factor is defined as:

F(p1, p2, …, pn, c) = 0 <=> c = AND(p1, p2, …, pn)
F(p1, p2, …, pn, c) = -inf otherwise

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ANDFactors the edge_states_offset is -1.
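Because the child of an AND factor is fully determined by its parents, exactly 2^n of the 2^(n+1) joint assignments have log potential 0. The following sketch enumerates them; and_factor_valid_configs is a hypothetical helper, not a PGMax function.

```python
import itertools

def and_factor_valid_configs(num_parents):
    """List every (p1, ..., pn, c) with c == AND(p1, ..., pn)."""
    configs = []
    for parents in itertools.product((0, 1), repeat=num_parents):
        child = int(all(parents))  # the only child value with log potential 0
        configs.append(parents + (child,))
    return configs

configs = and_factor_valid_configs(2)
```

For two parents this yields four valid configurations, and the child is 1 only in the last one, where both parents are 1.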

pool

class pgmax.factor.PoolWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], pool_choices_edge_states: Union[numpy.ndarray, jax.Array], pool_indicators_edge_states: Union[numpy.ndarray, jax.Array])

Wiring for PoolFactors.

pool_choices_edge_states

Array of shape (num_pool_choices, 2).
  • pool_choices_edge_states[ii, 0] contains the global PoolFactor index
  • pool_choices_edge_states[ii, 1] contains the message index of the pool choice variable’s state 0

The message index of the pool choice variable’s state 1 is pool_choices_edge_states[ii, 1] + 1. Both indices only take into account the PoolFactors of the FactorGraph.

pool_indicators_edge_states

Array of shape (num_pool_factors,). pool_indicators_edge_states[ii] contains the message index of the pool indicator variable’s state 0. The message index of the pool indicator variable’s state 1 is pool_indicators_edge_states[ii] + 1. Only takes into account the PoolFactors of the FactorGraph.

Raises:

ValueError – If:
  (1) There are not num_pool_factors different factor indices
  (2) There is a factor index higher than num_pool_factors - 1

class pgmax.factor.PoolFactor(variables: List[Tuple[int, int]])

A Pool factor of the form (pc1, …,pcn, pi) where (pc1,…,pcn) are the pool choices and pi is the pool indicator.

A Pool factor is defined as:

F(pc1, …,pcn, pi) = 0 <=> (pc1=…=pcn=pi=0) OR (pi=1 AND pc1 +…+ pcn=1)
F(pc1, …,pcn, pi) = -inf otherwise

That is, either (a) all the variables are set to 0, or (b) the pool indicator variable is set to 1 and exactly one of the pool choice variables is set to 1.

Note: placing the pool indicator at the end allows us to reuse our existing infrastructure for wiring logical factors.
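The two admissible cases of the Pool factor can be checked directly. This is a small sketch of the definition only; pool_factor_log_potential is a hypothetical name and the code says nothing about how PGMax wires or evaluates pool factors.

```python
import math

def pool_factor_log_potential(pool_choices, pool_indicator):
    """Return 0.0 for a valid pool assignment and -inf otherwise."""
    values = list(pool_choices) + [pool_indicator]
    if any(value not in (0, 1) for value in values):
        raise ValueError("All variables of a pool factor must be binary")
    num_active = sum(pool_choices)
    all_off = pool_indicator == 0 and num_active == 0       # case (a)
    single_choice = pool_indicator == 1 and num_active == 1  # case (b)
    return 0.0 if (all_off or single_choice) else -math.inf

everything_off = pool_factor_log_potential([0, 0, 0], 0)
one_choice = pool_factor_log_potential([0, 1, 0], 1)
two_choices = pool_factor_log_potential([1, 1, 0], 1)
```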

Fgraph

A sub-package containing functions to represent a factor graph.

fgraph

A module containing the core class to build a factor graph.

fgraph

class pgmax.fgraph.FactorGraphState(variable_groups: Sequence[pgmax.vgroup.vgroup.VarGroup], vars_to_starts: Mapping[Tuple[int, int], int], num_var_states: int, total_factor_num_states: int, factor_type_to_msgs_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_type_to_potentials_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_group_to_potentials_starts: OrderedDict[pgmax.fgroup.fgroup.FactorGroup, int], log_potentials: numpy.ndarray, evidence_to_vars: numpy.ndarray, wiring: OrderedDict[Type[pgmax.factor.factor.Factor], pgmax.factor.factor.Wiring])

FactorGraphState.

variable_groups

VarGroups in the FactorGraph.

vars_to_starts

Maps variables to their starting indices in the flat evidence array. The slice flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] contains the evidence for the variable.
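The slicing rule for vars_to_starts can be sketched with plain lists. The example assumes that each variable is a (hash, num_states) tuple, as described for Factor.variables, so the number of states is the second tuple entry; the hashes, the evidence values, and the helper name evidence_for_variable are all made up for illustration.

```python
def evidence_for_variable(flat_evidence, vars_to_starts, variable):
    """Return the slice of the flat evidence array that belongs to `variable`."""
    _, num_states = variable  # assumption: variable is (hash, num_states)
    start = vars_to_starts[variable]
    return flat_evidence[start:start + num_states]

# One binary and one ternary variable, laid out back to back.
vars_to_starts = {(101, 2): 0, (102, 3): 2}
flat_evidence = [0.0, 0.1, 1.0, 1.1, 1.2]

first = evidence_for_variable(flat_evidence, vars_to_starts, (101, 2))
second = evidence_for_variable(flat_evidence, vars_to_starts, (102, 3))
```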

num_var_states

Total number of variable states.

total_factor_num_states

Size of the flat ftov messages array.

factor_type_to_msgs_range

Maps factor types to their start and end indices in the flat ftov messages.

factor_type_to_potentials_range

Maps factor types to their start and end indices in the flat log potentials.

factor_group_to_potentials_starts

Maps factor groups to their starting indices in the flat log potentials.

log_potentials

Flat log potentials array concatenated for each factor type.

evidence_to_vars

Maps the evidence entries to their variable indices

wiring

Wiring derived for each factor type.

class pgmax.fgraph.FactorGraph(variable_groups: Union[pgmax.vgroup.vgroup.VarGroup, Sequence[pgmax.vgroup.vgroup.VarGroup]])

Class for representing a factor graph.

Factors in a graph are clustered in factor groups, which are grouped according to their factor types.

variable_groups

A single VarGroup or a list of VarGroups.

Fgroup

A sub-package defining different types of groups of factors.

enum

Defines EnumFactorGroup and PairwiseFactorGroup.

fgroup

A module containing the base classes for factor groups in a factor graph.

logical

Defines LogicalFactorGroup and its two children, ORFactorGroup and ANDFactorGroup.

pool

Defines PoolFactorGroup.

fgroup

class pgmax.fgroup.FactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of Factors.

variables_for_factors

A list of lists of variables. Each list within the outer list contains the variables connected to a Factor. The same variable can be connected to multiple Factors.

factor_configs

Optional array containing an explicit enumeration of all valid configurations

log_potentials

Array of log potentials.

factor_type

Factor type shared by all the Factors in the FactorGroup.

Raises:

ValueError – if the FactorGroup does not contain a Factor

class pgmax.fgroup.SingleFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], single_factor: pgmax.factor.factor.Factor)

Class to represent a FactorGroup with a single factor.

For internal use only. Should not be directly used to add FactorGroups to a factor graph.

single_factor

the single factor in the SingleFactorGroup

enum

class pgmax.fgroup.EnumFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], factor_configs: numpy.ndarray, log_potentials: Optional[Union[numpy.ndarray, jax.Array]] = None)

Class to represent a group of EnumFactors.

All factors in the group are assumed to have the same set of valid configurations. The associated log potentials can however be different across factors.

factor_configs

Array of shape (num_val_configs, num_variables) containing explicit enumeration of all valid configurations

log_potentials

Optional 1D array of shape (num_val_configs,) or 2D array of shape (num_factors, num_val_configs). If 1D, the log potentials are copied for each factor of the group. If 2D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.

factor_type

Factor type shared by all the Factors in the FactorGroup.

Raises:

ValueError – If:
  (1) The specified log_potentials is not of the expected shape
  (2) The dtype of the potential array is not float
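The three accepted forms of log_potentials (None, 1D, 2D) all resolve to one row of log potentials per factor. The sketch below mimics that rule with nested lists; expand_log_potentials is a hypothetical helper and is not meant to mirror PGMax's actual implementation, which operates on arrays.

```python
def expand_log_potentials(log_potentials, num_factors, num_val_configs):
    """Return a (num_factors, num_val_configs) nested list of log potentials."""
    if log_potentials is None:
        # Uniform zero log potentials for every factor.
        return [[0.0] * num_val_configs for _ in range(num_factors)]
    is_2d = len(log_potentials) > 0 and isinstance(log_potentials[0], (list, tuple))
    if is_2d:
        shape_ok = len(log_potentials) == num_factors and all(
            len(row) == num_val_configs for row in log_potentials
        )
        if not shape_ok:
            raise ValueError("Expected shape (num_factors, num_val_configs)")
        return [list(row) for row in log_potentials]
    if len(log_potentials) != num_val_configs:
        raise ValueError("Expected shape (num_val_configs,)")
    # 1D input: the same log potentials are copied for each factor.
    return [list(log_potentials) for _ in range(num_factors)]

shared = expand_log_potentials([0.0, 1.0], num_factors=3, num_val_configs=2)
uniform = expand_log_potentials(None, num_factors=2, num_val_configs=2)
per_factor = expand_log_potentials([[0.0, 1.0], [2.0, 3.0]], num_factors=2, num_val_configs=2)
```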

class pgmax.fgroup.PairwiseFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], log_potential_matrix: Optional[Union[numpy.ndarray, jax.Array]] = None)

Class to represent a group of EnumFactors where each factor connects to two different variables.

All factors in the group are assumed to be such that all possible configurations of the two variables are valid. The associated log potentials can however be different across factors.

log_potential_matrix

Optional 2D array of shape (num_states1, num_states2) or 3D array of shape (num_factors, num_states1, num_states2) where num_states1 and num_states2 are the number of states of the first and second variables involved in each factor. If 2D, the log potentials are copied for each factor of the group. If 3D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.

factor_type

Factor type shared by all the Factors in the FactorGroup.

Raises:

ValueError – If:
  (1) The specified log_potential_matrix is not a 2D or 3D array
  (2) The dtype of the potential array is not float
  (3) Some pairwise factors connect to fewer or more than 2 variables
  (4) The specified log_potential_matrix does not match the number of factors
  (5) The specified log_potential_matrix does not match the number of variable states of the variables in the factors
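Since a pairwise factor treats every joint state of its two variables as valid, a 2D log_potential_matrix can be viewed as an enumeration factor with num_states1 * num_states2 configurations. The sketch below shows one way to flatten it; the row-major ordering and the helper name pairwise_configs_and_potentials are assumptions made for illustration, not documented behavior.

```python
import itertools

def pairwise_configs_and_potentials(log_potential_matrix):
    """Flatten a (num_states1, num_states2) matrix into configs and log potentials."""
    num_states1 = len(log_potential_matrix)
    num_states2 = len(log_potential_matrix[0])
    # Assumption: configurations are enumerated in row-major order.
    factor_configs = list(itertools.product(range(num_states1), range(num_states2)))
    log_potentials = [log_potential_matrix[s1][s2] for s1, s2 in factor_configs]
    return factor_configs, log_potentials

matrix = [[0.0, 1.0, 2.0],
          [3.0, 4.0, 5.0]]
factor_configs, log_potentials = pairwise_configs_and_potentials(matrix)
```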

logical

class pgmax.fgroup.LogicalFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of LogicalFactors.

All factors in the group are assumed to have the same edge_states_offset. Consequently, the factors are all ORFactors or ANDFactors.

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1. For ANDFactors the edge_states_offset is -1.

class pgmax.fgroup.ORFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of ORFactors.

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ORFactors the edge_states_offset is 1.

class pgmax.fgroup.ANDFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of ANDFactors.

edge_states_offset

Offset to go from a variable’s relevant state to its other state. For ANDFactors the edge_states_offset is -1.

pool

class pgmax.fgroup.PoolFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])

Class to represent a group of PoolFactors.

Infer

A sub-package containing functions to perform belief propagation.

bp

A module containing the core message-passing functions for belief propagation.

bp_state

A module defining container classes for belief propagation states.

dual_lp

A module solving the smoothed dual of the LP relaxation of the MAP problem.

energy

Compute the energy of a MAP decoding.

inferer

Shared context classes for the inference methods.

bp

class pgmax.infer.BeliefPropagation(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays], run_bp: Callable[[], pgmax.infer.bp_state.BPArrays], run_with_diffs: Callable[[], Tuple[pgmax.infer.bp_state.BPArrays, jax.Array]])

Belief propagation functions.

run_bp

Backward compatible version of run.

run_with_diffs

Run inference while monitoring convergence

pgmax.infer.BP(bp_state: pgmax.infer.bp_state.BPState, temperature: Optional[float, None] = 0.0) → pgmax.infer.bp.BeliefPropagation

Returns the generated belief propagation functions.

Parameters:
  • bp_state – Belief propagation state.

  • temperature – Temperature for loopy belief propagation. 1.0 corresponds to sum-product, 0.0 corresponds to max-product. Used for backward compatibility

pgmax.infer.get_marginals(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]

Returns the normalized beliefs of several VarGroups, so that they form a valid probability distribution.

When the temperature is equal to 1.0, get_marginals returns the sum-product estimate of the marginal probabilities.

When the temperature is equal to 0.0, get_marginals returns the max-product estimate of the normalized max-marginal probabilities, defined as: norm_max_marginals(x_i^*) ∝ max_{x: x_i = x_i^*} p(x)

When the temperature is strictly between 0.0 and 1.0, get_marginals returns the belief propagation estimate of the normalized soft max-marginal probabilities, defined as: norm_soft_max_marginals(x_i^*) ∝ (sum_{x: x_i = x_i^*} p(x)^{1 /Temp})^Temp

Parameters:

beliefs – A dictionary containing the beliefs of the VarGroups.
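To make the three temperature regimes concrete, the sketch below computes the target quantities exactly by brute force on a tiny, made-up joint distribution over two binary variables. It does not call get_marginals, which works from beliefs rather than from an explicit joint; the function name normalized_soft_max_marginals and the numbers are illustrative assumptions.

```python
def normalized_soft_max_marginals(joint, num_states, var_index, temperature):
    """Brute-force the normalized (soft) max-marginals of one variable.

    joint maps full assignments (tuples of states) to probabilities p(x).
    """
    scores = []
    for state in range(num_states):
        probs = [p for x, p in joint.items() if x[var_index] == state]
        if temperature == 0.0:
            # Max-product limit: max_{x: x_i = state} p(x)
            scores.append(max(probs))
        else:
            # (sum_{x: x_i = state} p(x)^(1/Temp))^Temp
            scores.append(sum(p ** (1.0 / temperature) for p in probs) ** temperature)
    total = sum(scores)
    return [score / total for score in scores]

joint = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.2, (1, 1): 0.3}
sum_product = normalized_soft_max_marginals(joint, 2, 0, temperature=1.0)
max_product = normalized_soft_max_marginals(joint, 2, 0, temperature=0.0)
in_between = normalized_soft_max_marginals(joint, 2, 0, temperature=0.5)
```

At temperature 1.0 the first variable's two states are equally likely, whereas at temperature 0.0 state 0 is favored because it contains the single most probable joint assignment.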

bp_state

class pgmax.infer.BPArrays(log_potentials: Union[numpy.ndarray, jax.Array], ftov_msgs: Union[numpy.ndarray, jax.Array], evidence: Union[numpy.ndarray, jax.Array])

Container for the relevant flat arrays used in belief propagation.

log_potentials

Flat log potentials array.

ftov_msgs

Flat factor to variable messages array.

evidence

Flat evidence array.

class pgmax.infer.LogPotentials(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray, None] = None)

Class for storing and manipulating log potentials.

fg_state

Factor graph state

value

Optionally specify an initial value

Raises: ValueError if provided value shape does not match the expected log_potentials shape.

class pgmax.infer.FToVMessages(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray, None] = None)

Class for storing and manipulating factor to variable messages.

fg_state

Factor graph state

value

Optionally specify initial value for ftov messages

Raises: ValueError if provided value does not match expected ftov messages shape.

class pgmax.infer.Evidence(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray, None] = None)

Class for storing and manipulating evidence.

fg_state

Factor graph state

value

Optionally specify initial value for evidence

Raises: ValueError if provided value does not match expected evidence shape.

class pgmax.infer.BPState(log_potentials: pgmax.infer.bp_state.LogPotentials, ftov_msgs: pgmax.infer.bp_state.FToVMessages, evidence: pgmax.infer.bp_state.Evidence)

Container class for belief propagation states, including log potentials, ftov messages and evidence (unary log potentials).

log_potentials

log potentials of the model

ftov_msgs

factor to variable messages

evidence

evidence (unary log potentials) for variables.

fg_state

associated factor graph state

Raises: ValueError if log_potentials, ftov_msgs or evidence are not derived from the same FactorGraphState.

dual_lp

class pgmax.infer.SmoothDualLP(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays], run_with_objvals: Callable[[], Tuple[pgmax.infer.bp_state.BPArrays, float]], decode_primal_unaries: Callable[[], Tuple[Dict[Hashable, Any], Dict[Hashable, Any]]], get_primal_upper_bound: Callable[[], float], get_map_lower_bound: Callable[[], float], get_bp_updates: Callable[[], Tuple[jax.Array, jax.Array, jax.Array]])

Smooth Dual LP-MAP solver functions.

run_with_objvals

Solves the Smooth Dual LP-MAP problem via accelerated gradient descent (or subgradient descent) and returns the objective value at each step.

decode_primal_unaries

Decodes the primal LP-MAP unaries and returns a state assignment for each variable of the FactorGraph.

get_primal_upper_bound

Returns an upper bound of the optimal objective value of the (non smooth) LP-MAP problem.

get_map_lower_bound

Returns a lower bound of the optimal objective value of the (Integer Programming) MAP problem.

get_bp_updates

Used for unit tests. Gets the BP updates involved in the SDLP solver.

pgmax.infer.SDLP(bp_state: pgmax.infer.bp_state.BPState) → pgmax.infer.dual_lp.SmoothDualLP

Returns the generated Smooth Dual LP-MAP functions.

Parameters:

bp_state – Belief propagation state.

energy

pgmax.infer.compute_energy(bp_state: pgmax.infer.bp_state.BPState, bp_arrays: pgmax.infer.bp_state.BPArrays, map_states: Dict[Hashable, Any], debug_mode=False) → Tuple[float, Any, Any]

Return the energy of a decoding, expressed by its MAP states.

Parameters:
  • bp_state – Belief propagation state

  • bp_arrays – Arrays of log_potentials, ftov_msgs, evidence

  • map_states – A dictionary mapping the VarGroups of the FactorGraph to their MAP states

  • debug_mode – Debug mode returns the individual energies of each variable and factor in the FactorGraph

Returns:

  • energy – The energy of the decoding

  • vars_energies – The energy of each individual variable (only in debug mode)

  • factors_energies – The energy of each individual factor (only in debug mode)

Note: Remember that the lower the energy, the better the decoding!
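The following toy computation shows why a lower energy indicates a better decoding. It assumes the common convention that the energy is the negated sum of the unary log potentials (evidence) and the factor log potentials selected by the decoding; the data layout and the name toy_energy are hypothetical and do not reproduce the compute_energy signature.

```python
def toy_energy(evidence, factors, map_states):
    """Return (energy, vars_energies, factors_energies) for a decoding.

    evidence: {variable: list of unary log potentials, one per state}
    factors: list of (variables, {configuration: log potential})
    map_states: {variable: decoded state}
    """
    vars_energies = {var: -evidence[var][state] for var, state in map_states.items()}
    factors_energies = []
    for variables, table in factors:
        config = tuple(map_states[var] for var in variables)
        # Assumption: a configuration missing from the table is invalid.
        factors_energies.append(-table.get(config, float("-inf")))
    energy = sum(vars_energies.values()) + sum(factors_energies)
    return energy, vars_energies, factors_energies

evidence = {"a": [0.0, 1.0], "b": [0.5, 0.0]}
factors = [(("a", "b"), {(0, 0): 2.0, (1, 1): 2.0})]

energy_both_on, _, _ = toy_energy(evidence, factors, {"a": 1, "b": 1})
energy_both_off, _, _ = toy_energy(evidence, factors, {"a": 0, "b": 0})
energy_invalid, _, _ = toy_energy(evidence, factors, {"a": 0, "b": 1})
```

Under these assumptions the decoding with both variables on has the lowest energy, and a decoding that violates the factor gets an infinite energy.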

inferer

pgmax.infer.decode_map_states(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]

Returns the MAP states of several VarGroups given their beliefs.

Parameters:

beliefs – A dictionary containing the beliefs of the VarGroups.
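A natural reading of MAP decoding is to pick, for each variable, the state with the highest belief. The sketch below does this over nested lists, assuming the last axis of each belief array indexes the variable's states; argmax_states and the group names are hypothetical, and the real function operates on arrays.

```python
def argmax_states(beliefs):
    """Return, for each group, the index of the highest belief per variable."""
    return {
        name: [max(range(len(b)), key=b.__getitem__) for b in group_beliefs]
        for name, group_beliefs in beliefs.items()
    }

# Two groups: two binary variables and one ternary variable.
beliefs = {
    "pixels": [[0.1, 2.0], [1.5, -0.3]],
    "labels": [[0.0, 0.2, 0.1]],
}
map_states = argmax_states(beliefs)
```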

class pgmax.infer.Inferer(init: Callable[[], pgmax.infer.bp_state.BPArrays], update: Callable[[], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[], Dict[Hashable, Any]], run: Callable[[], pgmax.infer.bp_state.BPArrays])

Inferer pure functions.

init

Function to create log_potentials, ftov_msgs and evidence.

update

Function to update log_potentials, ftov_msgs and evidence.

to_bp_state

Function to reconstruct the BPState from a BPArrays.

get_beliefs

Function to calculate beliefs from a BPArrays.

run

Function to run inference.

class pgmax.infer.InfererContext(bp_state: pgmax.infer.bp_state.BPState)

Shared inference context for the different inferers.

bp_state

Belief propagation state.

Vgroup

A sub-package defining different types of groups of variables.

vgroup

A module containing the base class for variable groups in a Factor Graph.

varray

A module containing a subclass of VarGroup for n-dimensional grids of variables.

vdict

A module containing a variable dictionary class inheriting from the base VarGroup.

vgroup

class pgmax.vgroup.VarGroup(num_states: Union[int, numpy.ndarray])

Class to represent a group of variables.

Each variable is represented via a tuple of the form (variable hash, variable num_states).

num_states

An integer or an array specifying the number of states of the variables in this VarGroup

_hash

Hash value

varray

class pgmax.vgroup.NDVarArray(num_states: Union[int, numpy.ndarray], shape: Tuple[int, ])

Subclass of VarGroup for n-dimensional grids of variables.

num_states

An integer or an array specifying the number of states of the variables in this VarGroup

shape

Tuple specifying the size of each dimension of the grid (similar to the notion of a NumPy ndarray shape)

vdict

class pgmax.vgroup.VarDict(num_states: Union[int, numpy.ndarray], variable_names: Tuple[Any, ])

A variable dictionary that contains a set of variables.

num_states

The size of the variables in this VarGroup

variable_names

A tuple of all the names of the variables in this VarGroup.

Utils

A module containing helper functions.