PGMax API reference
This page lists the project's modules.
Factor
A sub-package defining different types of factors.
- factor – A module containing the base classes for factors in a factor graph.
- enum – Defines an enumeration factor.
- logical – Defines a logical factor.
- pool – Defines a pool factor.
factor
class pgmax.factor.Wiring(var_states_for_edges: Union[numpy.ndarray, jax.Array])
Wiring for factors.
- var_states_for_edges – Array of shape (num_edge_states, 3). For each edge state ii: var_states_for_edges[ii, 0] contains its global variable state index, var_states_for_edges[ii, 1] contains its global edge index, and var_states_for_edges[ii, 2] contains its global factor index.
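To make the row layout concrete, here is a standalone pure-Python sketch (it does not call PGMax) that builds these rows for a toy graph. The helper name build_var_states_for_edges, and the assumption that edges are numbered factor by factor in the order each factor lists its variables, are illustrative rather than taken from the library.

```python
from typing import Dict, List, Tuple

# A variable is a (variable hash, num_states) tuple, as in the docstrings above.
Variable = Tuple[int, int]


def build_var_states_for_edges(
    factors: List[List[Variable]], vars_to_starts: Dict[Variable, int]
) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: one row per edge state.

    Each row is (global variable state index, global edge index, global factor index).
    An edge is a (factor, variable) pair; an edge state is one state of that variable.
    """
    rows = []
    edge_idx = 0
    for factor_idx, variables in enumerate(factors):
        for variable in variables:
            _, num_states = variable
            start = vars_to_starts[variable]
            for state in range(num_states):
                rows.append((start + state, edge_idx, factor_idx))
            edge_idx += 1
    return rows


# Toy graph: a binary variable a and a ternary variable b, laid out contiguously.
a, b = (0, 2), (1, 3)
vars_to_starts = {a: 0, b: 2}
# Two factors: one over (a, b) and one over (b,).
rows = build_var_states_for_edges([[a, b], [b]], vars_to_starts)
for row in rows:
    print(row)
```

Variable b appears in two factors, so its three states show up twice, once per edge.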
class pgmax.factor.Factor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray)
A factor.
- variables – List of variables connected by the Factor. Each variable is represented by a tuple (variable hash, variable num_states).
- log_potentials – Array of log potentials.
- Raises:
NotImplementedError – If compile_wiring is not implemented.
enum
class pgmax.factor.EnumWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], factor_configs_edge_states: Union[numpy.ndarray, jax.Array])
Wiring for EnumFactors.
- factor_configs_edge_states – Array of shape (num_factor_configs, 2). factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices: factor_configs_edge_states[ii, 0] contains the global EnumFactor config index, and factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index. Both indices only take into account the EnumFactors of the FactorGraph.
- num_val_configs – Number of valid configurations for this wiring.
- num_factors – Number of factors covered by this wiring.
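The pairing can be illustrated with a standalone pure-Python sketch for a single EnumFactor (not PGMax code): each valid configuration is linked to one edge state per variable, namely that variable's state in the configuration, offset by the position where its edge begins. The helper name and the contiguous, factor-local edge-state layout starting at 0 are assumptions for illustration; in a full graph the indices would additionally be shifted to global positions.

```python
from itertools import accumulate
from typing import List, Tuple


def enum_factor_configs_edge_states(
    factor_configs: List[Tuple[int, ...]], edges_num_states: List[int]
) -> List[Tuple[int, int]]:
    """Hypothetical helper for one EnumFactor whose edge states start at 0."""
    # Edge state index at which each edge (one per variable of the factor) begins.
    edge_starts = [0] + list(accumulate(edges_num_states))[:-1]
    rows = []
    for config_idx, config in enumerate(factor_configs):
        for edge_start, state in zip(edge_starts, config):
            # Link this configuration to the edge state it activates on this edge.
            rows.append((config_idx, edge_start + state))
    return rows


# One factor over a binary and a ternary variable, with three valid configurations.
configs = [(0, 0), (0, 2), (1, 1)]
rows = enum_factor_configs_edge_states(configs, edges_num_states=[2, 3])
print(rows)  # [(0, 0), (0, 2), (1, 0), (1, 4), (2, 1), (2, 3)]
```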
class pgmax.factor.EnumFactor(variables: List[Tuple[int, int]], log_potentials: numpy.ndarray, factor_configs: numpy.ndarray)
An enumeration factor.
- factor_configs – Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations.
- log_potentials – Array of shape (num_val_configs,) containing the log of the potential value for each valid configuration.
- Raises:
ValueError – If: (1) the dtype of the factor_configs array is not int, (2) the dtype of the potential array is not float, (3) factor_configs does not have the correct shape, (4) the potential array does not have the correct shape, or (5) the factor_configs array contains invalid values.
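As an illustration of what an explicit enumeration looks like, the standalone sketch below (plain Python, not PGMax code) enumerates the valid configurations of a hypothetical "at most one variable is on" constraint over three binary variables and runs checks loosely mirroring conditions (3) to (5). The dtype checks (1) and (2) are omitted because the sketch uses Python lists rather than NumPy arrays; the constraint and the helper names are illustrative assumptions.

```python
from itertools import product
from typing import List, Sequence, Tuple


def at_most_one_on_configs(num_variables: int) -> List[Tuple[int, ...]]:
    """Enumerate every binary configuration with at most one variable set to 1."""
    return [c for c in product(range(2), repeat=num_variables) if sum(c) <= 1]


def check_enum_factor(
    factor_configs: Sequence[Sequence[int]],
    log_potentials: Sequence[float],
    num_states: Sequence[int],
) -> None:
    """Hypothetical checks loosely mirroring conditions (3)-(5)."""
    if any(len(config) != len(num_states) for config in factor_configs):
        raise ValueError("factor_configs must have shape (num_val_configs, num_variables)")
    if len(log_potentials) != len(factor_configs):
        raise ValueError("log_potentials must have shape (num_val_configs,)")
    for config in factor_configs:
        if any(not 0 <= s < n for s, n in zip(config, num_states)):
            raise ValueError("factor_configs contains an out-of-range state")


configs = at_most_one_on_configs(3)
print(configs)  # [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)]
check_enum_factor(configs, [0.0] * len(configs), num_states=[2, 2, 2])
```

Only the 4 listed configurations are stored, rather than all 2**3 = 8, which is the point of enumerating valid configurations explicitly.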
logical
class pgmax.factor.LogicalWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], parents_edge_states: Union[numpy.ndarray, jax.Array], children_edge_states: Union[numpy.ndarray, jax.Array], edge_states_offset: int)
Wiring for LogicalFactors.
- parents_edge_states – Array of shape (num_parents, 2). parents_edge_states[ii, 0] contains the global LogicalFactor index, and parents_edge_states[ii, 1] contains the message index of the parent variable's relevant state. The message index of the parent variable's other state is parents_edge_states[ii, 1] + edge_states_offset. Both indices only take into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.
- children_edge_states – Array of shape (num_factors,). children_edge_states[ii] contains the message index of the child variable's relevant state. The message index of the child variable's other state is children_edge_states[ii] + edge_states_offset. Only the LogicalFactors of the same subtype (OR/AND) of the FactorGraph are taken into account.
- edge_states_offset – Offset to go from a variable's relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.
- Raises:
ValueError – If: (1) there are not exactly num_logical_factors distinct factor indices, (2) there is a factor index higher than num_logical_factors - 1, or (3) the edge_states_offset is not 1 or -1.
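One way to read the offset sign: assuming each binary variable's two edge states sit at consecutive message indices with state 0 first, an offset of +1 means the relevant state is 0 (for OR, the child is 0 exactly when all parents are 0), and an offset of -1 means the relevant state is 1 (for AND, the child is 1 exactly when all parents are 1). The standalone sketch below encodes that assumption; the function name is hypothetical and this is not PGMax's implementation.

```python
from typing import Tuple


def relevant_and_other_index(edge_start: int, edge_states_offset: int) -> Tuple[int, int]:
    """Return (relevant-state index, other-state index) for one binary edge."""
    if edge_states_offset not in (1, -1):
        raise ValueError("edge_states_offset must be 1 or -1")
    # OR (offset 1): relevant state is 0; AND (offset -1): relevant state is 1.
    relevant_state = (1 - edge_states_offset) // 2
    relevant = edge_start + relevant_state
    return relevant, relevant + edge_states_offset


print(relevant_and_other_index(edge_start=6, edge_states_offset=1))   # (6, 7)
print(relevant_and_other_index(edge_start=6, edge_states_offset=-1))  # (7, 6)
```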
class pgmax.factor.LogicalFactor(variables: List[Tuple[int, int]])
A logical OR/AND factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.
- edge_states_offset – Offset to go from a variable's relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.
- Raises:
ValueError – If: (1) there are fewer than 2 variables, or (2) the variables are not all binary.
class pgmax.factor.ORFactor(variables: List[Tuple[int, int]])
An OR factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.
An OR factor is defined as:
$$F(p_1, \ldots, p_n, c) = \begin{cases} 0 & \text{if } c = \mathrm{OR}(p_1, \ldots, p_n) \\ -\infty & \text{otherwise} \end{cases}$$
- edge_states_offset – Offset to go from a variable's relevant state to its other state. For ORFactors the edge_states_offset is 1.
class pgmax.factor.ANDFactor(variables: List[Tuple[int, int]])
An AND factor of the form (p1, …, pn, c), where p1, …, pn are the parent variables and c is the child variable.
An AND factor is defined as:
$$F(p_1, \ldots, p_n, c) = \begin{cases} 0 & \text{if } c = \mathrm{AND}(p_1, \ldots, p_n) \\ -\infty & \text{otherwise} \end{cases}$$
- edge_states_offset – Offset to go from a variable's relevant state to its other state. For ANDFactors the edge_states_offset is -1.
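The OR and AND definitions can be checked directly by brute force. The standalone sketch below (plain Python, not PGMax code; the function names are hypothetical) evaluates each factor's log potential and lists the configurations with a finite value for two parents and one child, with the child in the last position as in (p1, …, pn, c).

```python
import math
from itertools import product
from typing import Sequence


def or_log_potential(parents: Sequence[int], child: int) -> float:
    """Log potential of an OR factor: 0 if child == OR(parents), -inf otherwise."""
    return 0.0 if child == int(any(parents)) else -math.inf


def and_log_potential(parents: Sequence[int], child: int) -> float:
    """Log potential of an AND factor: 0 if child == AND(parents), -inf otherwise."""
    return 0.0 if child == int(all(parents)) else -math.inf


# Valid (finite-potential) configurations for 2 parents and 1 child.
valid_or = [c for c in product((0, 1), repeat=3) if or_log_potential(c[:2], c[2]) == 0.0]
valid_and = [c for c in product((0, 1), repeat=3) if and_log_potential(c[:2], c[2]) == 0.0]
print(valid_or)   # [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]
print(valid_and)  # [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 1)]
```

In both cases exactly one child value is valid for each parent configuration, so each factor has 2**n valid configurations out of 2**(n + 1).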
pool
class pgmax.factor.PoolWiring(var_states_for_edges: Union[numpy.ndarray, jax.Array], pool_choices_edge_states: Union[numpy.ndarray, jax.Array], pool_indicators_edge_states: Union[numpy.ndarray, jax.Array])
Wiring for PoolFactors.
- pool_choices_edge_states – Array of shape (num_pool_choices, 2). pool_choices_edge_states[ii, 0] contains the global PoolFactor index, and pool_choices_edge_states[ii, 1] contains the message index of the pool choice variable's state 0. The message index of the pool choice variable's state 1 is pool_choices_edge_states[ii, 1] + 1. Both indices only take into account the PoolFactors of the FactorGraph.
- pool_indicators_edge_states – Array of shape (num_pool_factors,). pool_indicators_edge_states[ii] contains the message index of the pool indicator variable's state 0. The message index of the pool indicator variable's state 1 is pool_indicators_edge_states[ii] + 1. Only the PoolFactors of the FactorGraph are taken into account.
- Raises:
ValueError – If: (1) there are not exactly num_pool_factors distinct factor indices, or (2) there is a factor index higher than num_pool_factors - 1.
class pgmax.factor.PoolFactor(variables: List[Tuple[int, int]])
A Pool factor of the form (pc1, …, pcn, pi), where (pc1, …, pcn) are the pool choices and pi is the pool indicator.
A Pool factor is defined as:
$$F(pc_1, \ldots, pc_n, pi) = \begin{cases} 0 & \text{if } (pc_1 = \cdots = pc_n = pi = 0) \text{ or } (pi = 1 \text{ and } pc_1 + \cdots + pc_n = 1) \\ -\infty & \text{otherwise} \end{cases}$$
i.e. either (a) all the variables are set to 0, or (b) the pool indicator variable is set to 1 and exactly one of the pool choice variables is set to 1.
Note: placing the pool indicator at the end allows us to reuse our existing infrastructure for wiring logical factors.
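Brute-force enumeration makes the Pool constraint concrete: with n pool choices, exactly n + 1 configurations are valid (the all-zero configuration, plus one configuration per choice with the indicator on). The standalone sketch below (plain Python, not PGMax code; the function name is hypothetical) checks this for n = 3.

```python
import math
from itertools import product
from typing import Sequence


def pool_log_potential(choices: Sequence[int], indicator: int) -> float:
    """0 if all variables are 0, or indicator is 1 and exactly one choice is 1; else -inf."""
    all_off = indicator == 0 and sum(choices) == 0
    one_on = indicator == 1 and sum(choices) == 1
    return 0.0 if (all_off or one_on) else -math.inf


n = 3
# The pool indicator is placed last, matching the (pc1, ..., pcn, pi) ordering.
valid = [
    cfg for cfg in product((0, 1), repeat=n + 1)
    if pool_log_potential(cfg[:n], cfg[n]) == 0.0
]
print(len(valid))  # 4, i.e. n + 1
```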
Fgraph
A sub-package containing functions to represent a factor graph.
- fgraph – A module containing the core class to build a factor graph.
fgraph
class pgmax.fgraph.FactorGraphState(variable_groups: Sequence[pgmax.vgroup.vgroup.VarGroup], vars_to_starts: Mapping[Tuple[int, int], int], num_var_states: int, total_factor_num_states: int, factor_type_to_msgs_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_type_to_potentials_range: OrderedDict[Type[pgmax.factor.factor.Factor], Tuple[int, int]], factor_group_to_potentials_starts: OrderedDict[pgmax.fgroup.fgroup.FactorGroup, int], log_potentials: numpy.ndarray, evidence_to_vars: numpy.ndarray, wiring: OrderedDict[Type[pgmax.factor.factor.Factor], pgmax.factor.factor.Wiring])
FactorGraphState.
- variable_groups – VarGroups in the FactorGraph.
- vars_to_starts – Maps variables to their starting indices in the flat evidence array: flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + num_states] contains the evidence for the variable, where num_states is that variable's number of states.
- num_var_states – Total number of variable states.
- total_factor_num_states – Size of the flat ftov messages array.
- factor_type_to_msgs_range – Maps factor types to their start and end indices in the flat ftov messages.
- factor_type_to_potentials_range – Maps factor types to their start and end indices in the flat log potentials.
- factor_group_to_potentials_starts – Maps factor groups to their starting indices in the flat log potentials.
- log_potentials – Flat log potentials array concatenated for each factor type.
- evidence_to_vars – Maps the evidence entries to their variable indices.
- wiring – Wiring derived for each factor type.
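The flat-array bookkeeping can be pictured with a standalone pure-Python sketch (not PGMax code). It assumes variables are laid out back to back in a fixed order, so each start index is a running sum of the preceding variables' num_states; the helper name and the sample hashes are hypothetical, and the last lines show one plausible reading of evidence_to_vars.

```python
from itertools import accumulate
from typing import Dict, List, Tuple

Variable = Tuple[int, int]  # (variable hash, num_states)


def compute_vars_to_starts(variables: List[Variable]) -> Dict[Variable, int]:
    """Hypothetical layout: variables stored back to back in the given order."""
    num_states = [n for _, n in variables]
    starts = [0] + list(accumulate(num_states))[:-1]
    return dict(zip(variables, starts))


variables = [(10, 2), (11, 3), (12, 2)]
vars_to_starts = compute_vars_to_starts(variables)
num_var_states = sum(n for _, n in variables)  # size of the flat evidence array

# A flat evidence array, then the slice holding the ternary variable's evidence.
flat_evidence = [0.1, 0.2, 1.0, 2.0, 3.0, -0.5, 0.5]
var = (11, 3)
start = vars_to_starts[var]
print(flat_evidence[start : start + var[1]])  # [1.0, 2.0, 3.0]

# evidence_to_vars-style map: flat evidence entry -> index of its variable.
evidence_to_vars = [i for i, (_, n) in enumerate(variables) for _ in range(n)]
print(evidence_to_vars)  # [0, 0, 1, 1, 1, 2, 2]
```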
class pgmax.fgraph.FactorGraph(variable_groups: Union[pgmax.vgroup.vgroup.VarGroup, Sequence[pgmax.vgroup.vgroup.VarGroup]])
Class for representing a factor graph.
Factors in a graph are clustered in factor groups, which are grouped according to their factor types.
- variable_groups – A single VarGroup or a list of VarGroups.
Fgroup
A sub-package defining different types of groups of factors.
- enum – Defines EnumFactorGroup and PairwiseFactorGroup.
- fgroup – A module containing the base classes for factor groups in a factor graph.
- logical – Defines LogicalFactorGroup and its two children, ORFactorGroup and ANDFactorGroup.
- pool – Defines PoolFactorGroup.
fgroup
class pgmax.fgroup.FactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])
Class to represent a group of Factors.
- variables_for_factors – A list of lists of variables. Each list within the outer list contains the variables connected to a Factor. The same variable can be connected to multiple Factors.
- factor_configs – Optional array containing an explicit enumeration of all valid configurations.
- log_potentials – Array of log potentials.
- factor_type – Factor type shared by all the Factors in the FactorGroup.
- Raises:
ValueError – If the FactorGroup does not contain a Factor.
class pgmax.fgroup.SingleFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], single_factor: pgmax.factor.factor.Factor)
Class to represent a FactorGroup with a single factor.
For internal use only. It should not be used directly to add FactorGroups to a factor graph.
- single_factor – The single factor in the SingleFactorGroup.
enum
class pgmax.fgroup.EnumFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], factor_configs: numpy.ndarray, log_potentials: Optional[Union[numpy.ndarray, jax.Array]] = None)
Class to represent a group of EnumFactors.
All factors in the group are assumed to have the same set of valid configurations. The associated log potentials can, however, differ across factors.
- factor_configs – Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations.
- log_potentials – Optional 1D array of shape (num_val_configs,) or 2D array of shape (num_factors, num_val_configs). If 1D, the log potentials are copied for each factor of the group. If 2D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.
- factor_type – Factor type shared by all the Factors in the FactorGroup.
- Raises:
ValueError – If: (1) the specified log_potentials is not of the expected shape, or (2) the dtype of the potential array is not float.
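The three accepted forms of log_potentials amount to a broadcast to shape (num_factors, num_val_configs). The standalone sketch below spells out that behavior with nested Python lists instead of NumPy arrays, including a shape check in the spirit of condition (1); the helper name is hypothetical and this is not PGMax's implementation.

```python
from typing import List, Optional, Sequence, Union

LogPotentials = Union[Sequence[float], Sequence[Sequence[float]]]


def expand_log_potentials(
    log_potentials: Optional[LogPotentials], num_factors: int, num_val_configs: int
) -> List[List[float]]:
    """Hypothetical helper returning a (num_factors, num_val_configs) nested list."""
    if log_potentials is None:
        # None: uniform 0 for every factor.
        return [[0.0] * num_val_configs for _ in range(num_factors)]
    if all(isinstance(v, (int, float)) for v in log_potentials):
        # 1D: the same row is copied for each factor of the group.
        if len(log_potentials) != num_val_configs:
            raise ValueError("1D log_potentials must have shape (num_val_configs,)")
        return [list(log_potentials) for _ in range(num_factors)]
    # 2D: one row per factor.
    rows = [list(row) for row in log_potentials]
    if len(rows) != num_factors or any(len(r) != num_val_configs for r in rows):
        raise ValueError("2D log_potentials must have shape (num_factors, num_val_configs)")
    return rows


print(expand_log_potentials([0.5, -1.0], num_factors=3, num_val_configs=2))
# [[0.5, -1.0], [0.5, -1.0], [0.5, -1.0]]
print(expand_log_potentials(None, num_factors=2, num_val_configs=2))
# [[0.0, 0.0], [0.0, 0.0]]
```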
class pgmax.fgroup.PairwiseFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]], log_potential_matrix: Optional[Union[numpy.ndarray, jax.Array]] = None)
Class to represent a group of EnumFactors where each factor connects to two different variables.
All factors in the group are assumed to be such that all possible configurations of the two variables are valid. The associated log potentials can, however, differ across factors.
- log_potential_matrix – Optional 2D array of shape (num_states1, num_states2) or 3D array of shape (num_factors, num_states1, num_states2), where num_states1 and num_states2 are the numbers of states of the first and second variables involved in each factor. If 2D, the log potentials are copied for each factor of the group. If 3D, it specifies the log potentials of each factor. If None, the log potentials are initialized to uniform 0.
- factor_type – Factor type shared by all the Factors in the FactorGroup.
- Raises:
ValueError – If: (1) the specified log_potential_matrix is not a 2D or 3D array, (2) the dtype of the potential array is not float, (3) some pairwise factors do not connect to exactly 2 variables, (4) the specified log_potential_matrix does not match the number of factors, or (5) the specified log_potential_matrix does not match the number of variable states of the variables in the factors.
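Because every pair of states is valid, a pairwise factor can be viewed as an EnumFactor whose factor_configs list all num_states1 * num_states2 state pairs, with one log potential read off the matrix per pair. The standalone sketch below illustrates that correspondence in plain Python; the function name and the row-major configuration order are assumptions, not a description of PGMax internals.

```python
from itertools import product
from typing import List, Sequence, Tuple


def pairwise_to_enum(
    log_potential_matrix: Sequence[Sequence[float]],
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Hypothetical conversion of one 2D matrix into EnumFactor-style arrays."""
    num_states1 = len(log_potential_matrix)
    num_states2 = len(log_potential_matrix[0])
    # Every (state of variable 1, state of variable 2) pair is a valid configuration.
    factor_configs = list(product(range(num_states1), range(num_states2)))
    # Row-major lookup: one log potential per configuration.
    log_potentials = [log_potential_matrix[i][j] for i, j in factor_configs]
    return factor_configs, log_potentials


# An Ising-style coupling between two binary variables.
configs, pots = pairwise_to_enum([[1.0, -1.0], [-1.0, 1.0]])
print(configs)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(pots)     # [1.0, -1.0, -1.0, 1.0]
```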
logical
class pgmax.fgroup.LogicalFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])
Class to represent a group of LogicalFactors.
All factors in the group are assumed to have the same edge_states_offset. Consequently, the factors are either all ORFactors or all ANDFactors.
- edge_states_offset – Offset to go from a variable's relevant state to its other state. For ORFactors the edge_states_offset is 1; for ANDFactors it is -1.
pool
class pgmax.fgroup.PoolFactorGroup(variables_for_factors: Sequence[List[Tuple[int, int]]])
Class to represent a group of PoolFactors.
Infer
A sub-package containing functions to perform belief propagation.
- bp – A module containing the core message-passing functions for belief propagation.
- bp_state – A module defining container classes for belief propagation states.
- dual_lp – A module solving the smoothed dual of the LP relaxation of the MAP problem.
- energy – Compute the energy of a MAP decoding.
- inferer – Shared context classes for the inference methods.
bp
class pgmax.infer.BeliefPropagation(init: Callable[[…], pgmax.infer.bp_state.BPArrays], update: Callable[[…], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[…], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[…], Dict[Hashable, Any]], run: Callable[[…], pgmax.infer.bp_state.BPArrays], run_bp: Callable[[…], pgmax.infer.bp_state.BPArrays], run_with_diffs: Callable[[…], Tuple[pgmax.infer.bp_state.BPArrays, jax.Array]])
Belief propagation functions.
- run_bp – Backward-compatible version of run.
- run_with_diffs – Run inference while monitoring convergence.
pgmax.infer.BP(bp_state: pgmax.infer.bp_state.BPState, temperature: Optional[float] = 0.0) → pgmax.infer.bp.BeliefPropagation
Returns the generated belief propagation functions.
- Parameters:
bp_state – Belief propagation state.
temperature – Temperature for loopy belief propagation. 1.0 corresponds to sum-product and 0.0 corresponds to max-product. Used for backward compatibility.
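One common way to interpolate between the two regimes, shown here as a standalone illustration rather than PGMax's code, is to replace the reduction in each log-domain message update by T * logsumexp(x / T): T = 1.0 gives the log-sum-exp of sum-product, and the limit T → 0 gives the max of max-product. The function name is hypothetical.

```python
import math
from typing import Sequence


def soft_max_op(values: Sequence[float], temperature: float) -> float:
    """T * logsumexp(values / T); reduces to max(values) when T == 0."""
    if temperature == 0.0:
        return max(values)
    m = max(values)  # subtract the max for numerical stability
    return m + temperature * math.log(
        sum(math.exp((v - m) / temperature) for v in values)
    )


vals = [0.0, 1.0, 2.0]
print(soft_max_op(vals, 0.0))   # 2.0 (max-product)
print(soft_max_op(vals, 1.0))   # log(1 + e + e**2), about 2.4076 (sum-product)
print(soft_max_op(vals, 0.01))  # very close to 2.0
```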
pgmax.infer.get_marginals(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]
Returns the normalized beliefs of several VarGroups, so that they form a valid probability distribution.
When the temperature is equal to 1.0, get_marginals returns the sum-product estimate of the marginal probabilities.
When the temperature is equal to 0.0, get_marginals returns the max-product estimate of the normalized max-marginal probabilities, defined as:
$$\text{norm\_max\_marginals}(x_i^*) \propto \max_{x:\, x_i = x_i^*} p(x)$$
When the temperature is strictly between 0.0 and 1.0, get_marginals returns the belief propagation estimate of the normalized soft max-marginal probabilities, defined as:
$$\text{norm\_soft\_max\_marginals}(x_i^*) \propto \Big(\sum_{x:\, x_i = x_i^*} p(x)^{1/T}\Big)^{T}$$
where T is the temperature.
- Parameters:
beliefs – A dictionary containing the beliefs of the VarGroups.
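On a toy joint distribution, these quantities can be computed exactly by brute-force enumeration, which is what belief propagation estimates on larger graphs. The standalone sketch below (plain Python, not PGMax code; the function name and the toy distribution are hypothetical) computes the normalized quantities for one variable directly from p(x).

```python
from typing import Dict, List, Tuple


def soft_max_marginals(
    p: Dict[Tuple[int, ...], float], i: int, num_states: int, temperature: float
) -> List[float]:
    """Brute-force normalized (soft) max-marginals of variable i under p."""
    unnorm = []
    for s in range(num_states):
        probs = [px for x, px in p.items() if x[i] == s]
        if temperature == 0.0:
            unnorm.append(max(probs))  # max-marginal
        else:
            unnorm.append(sum(q ** (1.0 / temperature) for q in probs) ** temperature)
    total = sum(unnorm)
    return [u / total for u in unnorm]


# A toy joint distribution over two binary variables.
p = {(0, 0): 0.1, (0, 1): 0.4, (1, 0): 0.3, (1, 1): 0.2}
print(soft_max_marginals(p, i=0, num_states=2, temperature=1.0))  # about [0.5, 0.5]
print(soft_max_marginals(p, i=0, num_states=2, temperature=0.0))  # about [0.571, 0.429]
```

Here the marginals of the first variable are tied, while its max-marginals favor state 0 because the single most likely joint configuration, (0, 1), has that variable in state 0.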
bp_state
class pgmax.infer.BPArrays(log_potentials: Union[numpy.ndarray, jax.Array], ftov_msgs: Union[numpy.ndarray, jax.Array], evidence: Union[numpy.ndarray, jax.Array])
Container for the relevant flat arrays used in belief propagation.
- log_potentials – Flat log potentials array.
- ftov_msgs – Flat factor-to-variable messages array.
- evidence – Flat evidence array.
class pgmax.infer.LogPotentials(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)
Class for storing and manipulating log potentials.
- fg_state – Factor graph state.
- value – Optionally specify an initial value.
- Raises:
ValueError – If the provided value's shape does not match the expected log_potentials shape.
class pgmax.infer.FToVMessages(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)
Class for storing and manipulating factor-to-variable messages.
- fg_state – Factor graph state.
- value – Optionally specify an initial value for the ftov messages.
- Raises:
ValueError – If the provided value does not match the expected ftov messages shape.
class pgmax.infer.Evidence(fg_state: pgmax.fgraph.fgraph.FactorGraphState, value: Optional[numpy.ndarray] = None)
Class for storing and manipulating evidence.
- fg_state – Factor graph state.
- value – Optionally specify an initial value for the evidence.
- Raises:
ValueError – If the provided value does not match the expected evidence shape.
class pgmax.infer.BPState(log_potentials: pgmax.infer.bp_state.LogPotentials, ftov_msgs: pgmax.infer.bp_state.FToVMessages, evidence: pgmax.infer.bp_state.Evidence)
Container class for belief propagation states, including log potentials, ftov messages and evidence (unary log potentials).
- log_potentials – Log potentials of the model.
- ftov_msgs – Factor-to-variable messages.
- evidence – Evidence (unary log potentials) for variables.
- fg_state – Associated factor graph state.
- Raises:
ValueError – If log_potentials, ftov_msgs or evidence are not derived from the same factor graph state.
dual_lp
class pgmax.infer.SmoothDualLP(init: Callable[[…], pgmax.infer.bp_state.BPArrays], update: Callable[[…], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[…], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[…], Dict[Hashable, Any]], run: Callable[[…], pgmax.infer.bp_state.BPArrays], run_with_objvals: Callable[[…], Tuple[pgmax.infer.bp_state.BPArrays, float]], decode_primal_unaries: Callable[[…], Tuple[Dict[Hashable, Any], Dict[Hashable, Any]]], get_primal_upper_bound: Callable[[…], float], get_map_lower_bound: Callable[[…], float], get_bp_updates: Callable[[…], Tuple[jax.Array, jax.Array, jax.Array]])
Smooth Dual LP-MAP solver functions.
- run_with_objvals – Solves the Smooth Dual LP-MAP problem via accelerated gradient descent (or subgradient descent) and returns the objective value at each step.
- decode_primal_unaries – Decodes the primal LP-MAP unaries and returns a state assignment for each variable of the FactorGraph.
- get_primal_upper_bound – Returns an upper bound of the optimal objective value of the (non-smooth) LP-MAP problem.
- get_map_lower_bound – Returns a lower bound of the optimal objective value of the (Integer Programming) MAP problem.
- get_bp_updates – Used for unit tests. Gets the BP updates involved in the SDLP solver.
pgmax.infer.SDLP(bp_state: pgmax.infer.bp_state.BPState) → pgmax.infer.dual_lp.SmoothDualLP
Returns the generated Smooth Dual LP-MAP functions.
- Parameters:
bp_state – Belief propagation state.
energy
pgmax.infer.compute_energy(bp_state: pgmax.infer.bp_state.BPState, bp_arrays: pgmax.infer.bp_state.BPArrays, map_states: Dict[Hashable, Any], debug_mode=False) → Tuple[float, Any, Any]
Returns the energy of a decoding, expressed by its MAP states.
- Parameters:
bp_state – Belief propagation state.
bp_arrays – Arrays of log_potentials, ftov_msgs and evidence.
map_states – A dictionary mapping the VarGroups of the FactorGraph to their MAP states.
debug_mode – Debug mode returns the individual energies of each variable and factor in the FactorGraph.
- Returns:
energy – The energy of the decoding.
vars_energies – The energy of each individual variable (only in debug mode).
factors_energies – The energy of each individual factor (only in debug mode).
- Return type:
Tuple[float, Any, Any]
Note: Remember that the lower the energy, the better the decoding!
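Under the usual convention that energy is the negated unnormalized log-probability, the energy of a decoding is minus the sum of the evidence (unary log potentials) and the factor log potentials evaluated at the decoded states, which is why lower is better. The standalone sketch below assumes that convention and uses dictionaries instead of flat arrays; the function and variable names are hypothetical and this is not PGMax's implementation.

```python
from typing import Dict, List, Sequence, Tuple


def decoding_energy(
    evidence: Dict[str, Sequence[float]],
    factors: List[Tuple[Tuple[str, ...], Dict[Tuple[int, ...], float]]],
    map_states: Dict[str, int],
) -> float:
    """Hypothetical energy: negated sum of unary and factor log potentials."""
    energy = 0.0
    for var, log_unaries in evidence.items():
        energy -= log_unaries[map_states[var]]
    for var_names, log_potentials in factors:
        config = tuple(map_states[v] for v in var_names)
        # A configuration missing from the table is invalid: infinite energy.
        energy -= log_potentials.get(config, float("-inf"))
    return energy


evidence = {"a": [0.0, 1.0], "b": [0.5, 0.0]}
# One pairwise factor that rewards a and b taking the same state.
agree = {(0, 0): 1.0, (0, 1): 0.0, (1, 0): 0.0, (1, 1): 1.0}
factors = [(("a", "b"), agree)]
print(decoding_energy(evidence, factors, {"a": 1, "b": 1}))  # -2.0
print(decoding_energy(evidence, factors, {"a": 0, "b": 1}))  # 0.0
```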
inferer
pgmax.infer.decode_map_states(beliefs: Dict[Hashable, Any]) → Dict[Hashable, Any]
Returns the MAP states of several VarGroups given their beliefs.
- Parameters:
beliefs – A dictionary containing the beliefs of the VarGroups.
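Decoding from beliefs is typically a per-variable argmax over states. The standalone sketch below (plain Python, not PGMax code) assumes that behavior and represents each VarGroup's beliefs as a list of per-variable belief vectors keyed by a hypothetical group name.

```python
from typing import Dict, Hashable, List, Sequence


def argmax_states(
    beliefs: Dict[Hashable, Sequence[Sequence[float]]],
) -> Dict[Hashable, List[int]]:
    """Hypothetical decoder: for every variable, pick the state with the highest belief."""
    return {
        group: [max(range(len(b)), key=b.__getitem__) for b in group_beliefs]
        for group, group_beliefs in beliefs.items()
    }


beliefs = {"pixels": [[0.2, 1.3], [2.0, -1.0], [0.0, 0.5]]}
print(argmax_states(beliefs))  # {'pixels': [1, 0, 1]}
```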
class pgmax.infer.Inferer(init: Callable[[…], pgmax.infer.bp_state.BPArrays], update: Callable[[…], pgmax.infer.bp_state.BPArrays], to_bp_state: Callable[[…], pgmax.infer.bp_state.BPArrays], get_beliefs: Callable[[…], Dict[Hashable, Any]], run: Callable[[…], pgmax.infer.bp_state.BPArrays])
Inferer pure functions.
- init – Function to create log_potentials, ftov_msgs and evidence.
- update – Function to update log_potentials, ftov_msgs and evidence.
- to_bp_state – Function to reconstruct the BPState from a BPArrays.
- get_beliefs – Function to calculate beliefs from a BPArrays.
- run – Function to run inference.
Vgroup
A sub-package defining different types of groups of variables.
- vgroup – A module containing the base class for variable groups in a Factor Graph.
- varray – A module containing a subclass of VarGroup for n-dimensional grids of variables.
- vdict – A module containing a variable dictionary class inheriting from the base VarGroup.
vgroup
class pgmax.vgroup.VarGroup(num_states: Union[int, numpy.ndarray])
Class to represent a group of variables.
Each variable is represented via a tuple of the form (variable hash, variable num_states).
- num_states – An integer or an array specifying the number of states of the variables in this VarGroup.
- _hash – Hash value.
varray
class pgmax.vgroup.NDVarArray(num_states: Union[int, numpy.ndarray], shape: Tuple[int, …])
Subclass of VarGroup for n-dimensional grids of variables.
- num_states – An integer or an array specifying the number of states of the variables in this VarGroup.
- shape – Tuple specifying the size of each dimension of the grid (similar to the notion of a NumPy ndarray shape).
vdict
Utils
A module containing helper functions.