Machine Learning — Graphical Model Exact inference (Variable elimination, Belief propagation, Junction tree)
We build inference systems to emulate human intelligence. With a probabilistic model in Machine Learning (ML), we model a problem as the joint probability of the observed and the hidden variables.
We can generalize this model as a joint probability of variables:
Typical inference queries fall into three categories:
If e is the set of observed variables, p(e) is the likelihood of the observed data. Imagine a joint distribution p(x) that contains some of the variables below:
All the familiar inferences below fall into these three categories.
We can solve all these inferences through marginal inference
and/or MAP (maximum a posteriori) inference.
As shown, everything starts with the joint probability: a query over any subset of variables can be answered from it. This is why the probabilistic method models an ML problem as a joint probability distribution.
However, marginal inference and MAP inference are NP-hard in general. Unless we discover a simple structure underneath the data, answering a query requires work that grows exponentially with the number of variables. Specifically, in marginal inference, we sum over an exponential number of terms to calculate the partition function Z. Even though MAP inference does not involve Z, we may still need to enumerate an exponential number of configurations to find the maximum. In our previous article on the Graphical model, we exploited independence among variables to reduce this complexity.
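A minimal brute-force sketch makes the exponential cost concrete. It assumes a hypothetical chain of N binary variables whose pairwise potential (our own made-up choice) favors neighbors that agree; `brute_force_partition` computes Z by enumerating every one of the 2ᴺ assignments.

```python
import itertools
import math

def pairwise_potential(a: int, b: int) -> float:
    # Hypothetical potential: neighbors that agree get twice the weight.
    return 2.0 if a == b else 1.0

def brute_force_partition(n: int) -> tuple[float, int]:
    """Sum the unnormalized chain model over all 2**n assignments."""
    z = 0.0
    count = 0
    for x in itertools.product((0, 1), repeat=n):
        z += math.prod(pairwise_potential(x[i], x[i + 1]) for i in range(n - 1))
        count += 1
    return z, count

for n in (4, 8, 12):
    z, count = brute_force_partition(n)
    print(f"N={n:2d}: enumerated {count} assignments, Z={z:.0f}")
```

The number of enumerated assignments doubles with every added variable, which is exactly the cost the methods in this article try to avoid.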
If we discover enough independence, or make bold enough independence assumptions, like those in the Naive Bayes Algorithm or the Hidden Markov Model (HMM), we can make the model tractable.
For example, the Naive Bayes Algorithm reduces the model to distributions over single variables (each feature conditioned on the class), which makes MAP inference tractable.
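As a sketch of why this factorization helps, assume a hypothetical Naive Bayes model with a class y and three binary features that are independent given y. The MAP class is argmax_y p(y) ∏ᵢ p(xᵢ|y), which costs O(|Y|·N) operations instead of touching a joint table of exponential size. The class names, priors, and likelihoods below are made-up numbers; we work in log space to avoid underflow.

```python
import math

# Hypothetical parameters: prior p(y) and p(x_i = 1 | y) for 3 binary features.
prior = {"spam": 0.4, "ham": 0.6}
p_feature_on = {
    "spam": [0.8, 0.6, 0.1],
    "ham": [0.1, 0.3, 0.5],
}

def naive_bayes_map(x: list[int]) -> str:
    """Return argmax_y log p(y) + sum_i log p(x_i | y)."""
    def log_score(y: str) -> float:
        s = math.log(prior[y])
        for xi, p in zip(x, p_feature_on[y]):
            s += math.log(p if xi == 1 else 1.0 - p)
        return s
    return max(prior, key=log_score)

print(naive_bayes_map([1, 1, 0]))  # spam
print(naive_bayes_map([0, 0, 1]))  # ham
```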
In general, the Graphical model replaces the joint distribution with a product of factors defined over the graph.
All these discussions lead to an important conclusion.
In general, inference is NP-hard and difficult to solve.
But if the model’s factors are simple enough, we can find an exact solution efficiently. If they are not, it remains possible to find a good solution with approximation methods. You will soon realize that a huge amount of ML effort is devoted to these problems. For the remainder of this article, we discuss how exact inference can be done, provided the model is simple enough. We will cover the approximation methods in later articles.
However, in many ML problems, we are interested in a narrow set of queries in the form of a conditional distribution. For example, in a classification problem, we are interested in P(y|x) and not in other possible queries. To address this type of problem, we model the query directly using a Conditional Random Field (CRF) like the one below.
Our conditional probability is factored just like other Graphical models:
Then, we can solve the problem with the Graphical model methods. The following is the MAP inference for the CRF above.
Variable elimination
Variable elimination solves our inference query exactly.
The concept of variable elimination is to eliminate one variable at a time from the marginal distribution expression.
For simplicity, we assume each variable has k unique values. The most naive way to find p(l) is to sum p(l, *, *, *, *) over all kᴺ⁻¹ possible values of the other N-1 variables. To reduce the number of operations, we can instead rearrange the order of the summations and push each one as close as possible to the probability distributions that depend on it.
Starting from the right end, we marginalize one variable at a time. In each step below, we create a new factor τ to represent the result of the marginalization; for example, τ(a, b, c) is an intermediate factor that depends only on a, b, and c. If we choose the elimination order smartly, we can reduce the computational complexity significantly. For example, computing the marginal below costs O(k³) instead of O(k⁴) with the brute-force approach.
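Here is a minimal NumPy sketch of the same idea on a hypothetical chain a → b → c → d (not the model in the figure), with random conditional tables. Brute force builds the full k⁴ joint table; elimination pushes each sum inward, so every step is a vector-matrix product costing O(k²).

```python
import numpy as np

rng = np.random.default_rng(0)
k = 3  # number of values per variable

def random_cpt(*shape: int) -> np.ndarray:
    """Random conditional table normalized over the last axis."""
    t = rng.random(shape)
    return t / t.sum(axis=-1, keepdims=True)

# Hypothetical chain a -> b -> c -> d.
p_a = random_cpt(k)          # p(a)
p_b_a = random_cpt(k, k)     # p(b | a), indexed [a, b]
p_c_b = random_cpt(k, k)     # p(c | b), indexed [b, c]
p_d_c = random_cpt(k, k)     # p(d | c), indexed [c, d]

# Brute force: build the full k**4 joint table, then sum out a, b, c.
joint = np.einsum("a,ab,bc,cd->abcd", p_a, p_b_a, p_c_b, p_d_c)
p_d_brute = joint.sum(axis=(0, 1, 2))

# Variable elimination: push each sum inward, one variable at a time.
tau_b = p_a @ p_b_a        # tau(b) = sum_a p(a) p(b|a)
tau_c = tau_b @ p_c_b      # tau(c) = sum_b tau(b) p(c|b)
p_d_ve = tau_c @ p_d_c     # p(d)   = sum_c tau(c) p(d|c)

print(np.allclose(p_d_brute, p_d_ve))  # True
```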
Let’s generalize the concept to MRFs. For simplicity, we ignore the partition function Z in our discussion, since it can be calculated with further variable elimination.
Consider the joint probability:
The variable elimination algorithm performs the following steps using an ordered list of variables to be eliminated.
- Multiply all factors Φ that contain the variable to eliminate, say b
- Marginalize b to obtain a new factor τ
- Replace these factors Φ with the new intermediate factor τ
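The three steps can be sketched with NumPy, assuming a simple factor representation of our own (a tuple of variable names plus a table with one axis per variable; this is not a library API). `multiply` aligns axes by name with `np.einsum`, and `variable_elimination` runs the multiply, marginalize, replace loop for a given order.

```python
from string import ascii_letters
import numpy as np

# A factor is (variable names, table with one axis per variable). This representation
# is our own convention for the sketch, not a library API.
Factor = tuple[tuple[str, ...], np.ndarray]

def multiply(factors: list[Factor]) -> Factor:
    """Multiply factors, aligning table axes by variable name (up to 52 variables)."""
    all_vars: list[str] = []
    for names, _ in factors:
        for v in names:
            if v not in all_vars:
                all_vars.append(v)
    letter = {v: ascii_letters[i] for i, v in enumerate(all_vars)}
    inputs = ",".join("".join(letter[v] for v in names) for names, _ in factors)
    output = "".join(letter[v] for v in all_vars)
    table = np.einsum(f"{inputs}->{output}", *(t for _, t in factors))
    return tuple(all_vars), table

def sum_out(factor: Factor, var: str) -> Factor:
    """Marginalize var out of a factor."""
    names, table = factor
    axis = names.index(var)
    return names[:axis] + names[axis + 1:], table.sum(axis=axis)

def variable_elimination(factors: list[Factor], order: list[str]) -> Factor:
    for var in order:
        related = [f for f in factors if var in f[0]]   # step 1: factors containing var
        if not related:
            continue
        others = [f for f in factors if var not in f[0]]
        tau = sum_out(multiply(related), var)           # step 2: multiply, then sum out var
        factors = others + [tau]                        # step 3: replace them with tau
    return multiply(factors)

# Hypothetical MRF over a, b, c with factors phi(a, b) and phi(b, c).
rng = np.random.default_rng(1)
phi_ab = (("a", "b"), rng.random((2, 3)))
phi_bc = (("b", "c"), rng.random((3, 2)))
vars_, unnormalized = variable_elimination([phi_ab, phi_bc], order=["c", "b"])
p_a = unnormalized / unnormalized.sum()   # dividing by Z normalizes
brute = np.einsum("ab,bc->a", phi_ab[1], phi_bc[1])
print(vars_, np.allclose(p_a, brute / brute.sum()))  # ('a',) True
```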
Example:
Since we multiply factors and then sum over a variable, this is also called sum-product inference. Here is an example of how to marginalize a variable.
Finding the optimal elimination order is an NP-hard problem. But for some graphs, an optimal or near-optimal order is obvious. In general, we can follow the guidelines below:
- Choose the variable with the fewest dependent variables (min-neighbors), i.e. the vertex with the fewest neighbors.
- Choose the variable that produces the smallest merged factor.
- Choose the variable that minimizes the number of fill edges (discussed later) added to the graph when it is eliminated (min-fill).
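Greedy heuristics like these are easy to sketch. Below, a graph is a dict mapping each variable to its set of neighbors (an assumed representation), and `cost` is a pluggable heuristic; `min_neighbors` and `min_fill` are our own helper names. Ties are broken alphabetically, an arbitrary choice.

```python
from itertools import combinations

def min_neighbors(g: dict[str, set[str]], v: str) -> int:
    """Cost = number of neighbors of v."""
    return len(g[v])

def min_fill(g: dict[str, set[str]], v: str) -> int:
    """Cost = number of fill edges needed to make v's neighbors a clique."""
    return sum(1 for a, b in combinations(sorted(g[v]), 2) if b not in g[a])

def greedy_order(graph: dict[str, set[str]], cost) -> list[str]:
    """Repeatedly eliminate the cheapest variable (ties broken alphabetically)."""
    g = {v: set(nbrs) for v, nbrs in graph.items()}  # work on a copy
    order = []
    while g:
        v = min(sorted(g), key=lambda u: cost(g, u))
        nbrs = g.pop(v)
        for a, b in combinations(nbrs, 2):  # connect v's neighbors (fill edges)
            g[a].add(b)
            g[b].add(a)
        for u in nbrs:
            g[u].discard(v)
        order.append(v)
    return order

# Hypothetical graph: the 4-cycle a-b-c-d-a plus a leaf e attached to a.
graph = {"a": {"b", "d", "e"}, "b": {"a", "c"}, "c": {"b", "d"},
         "d": {"a", "c"}, "e": {"a"}}
print(greedy_order(graph, min_neighbors))  # ['e', 'a', 'b', 'c', 'd']
print(greedy_order(graph, min_fill))       # ['e', 'a', 'b', 'c', 'd']
```

Both heuristics pick the leaf e first because eliminating it creates no new edges; on larger graphs they can disagree.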
Example
(Credit: example adapted from here.)
To compute P(A|H=h) with the elimination list {H, G, F, E, D, C, B}, we eliminate the variables from the BN one at a time in the order of the list.
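The figure is not reproduced here, so this sketch assumes the factorization P(a)P(b)P(c|b)P(d|a)P(e|c,d)P(f|a)P(g|e)P(h|e,f), a standard textbook version of this example that is consistent with the fill edge and messages described below. Binary variables, random tables, and the observed value h = 1 are also assumptions. Each `m_*` array is the intermediate factor left after eliminating that variable, and the result is checked against brute force.

```python
import numpy as np

rng = np.random.default_rng(0)

def cpt(*shape: int) -> np.ndarray:
    """Random conditional table; the last axis is the child variable."""
    t = rng.random(shape)
    return t / t.sum(axis=-1, keepdims=True)

# Assumed BN: P(a)P(b)P(c|b)P(d|a)P(e|c,d)P(f|a)P(g|e)P(h|e,f), all binary.
p_a, p_b = cpt(2), cpt(2)
p_c_b, p_d_a, p_f_a, p_g_e = cpt(2, 2), cpt(2, 2), cpt(2, 2), cpt(2, 2)
p_e_cd, p_h_ef = cpt(2, 2, 2), cpt(2, 2, 2)
h = 1  # observed value of H

# Eliminate H, G, F, E, D, C, B in order; each m is an intermediate factor.
m_h = p_h_ef[:, :, h]                               # m_h(e, f): H clamped to the evidence
m_g = p_g_e.sum(axis=1)                             # m_g(e) = sum_g P(g|e), all ones
m_f = np.einsum("af,ef->ae", p_f_a, m_h)            # m_f(a, e)
m_e = np.einsum("cde,e,ae->acd", p_e_cd, m_g, m_f)  # m_e(a, c, d)
m_d = np.einsum("ad,acd->ac", p_d_a, m_e)           # m_d(a, c)
m_c = np.einsum("bc,ac->ab", p_c_b, m_d)            # m_c(a, b)
m_b = np.einsum("b,ab->a", p_b, m_c)                # m_b(a)
posterior_ve = p_a * m_b / (p_a * m_b).sum()        # P(A | H = h)

# Brute-force check: build the full 2**8 joint table.
joint = np.einsum("a,b,bc,ad,cde,af,eg,efh->abcdefgh",
                  p_a, p_b, p_c_b, p_d_a, p_e_cd, p_f_a, p_g_e, p_h_ef)
p_a_h = joint[..., h].sum(axis=(1, 2, 3, 4, 5, 6))
print(np.allclose(posterior_ve, p_a_h / p_a_h.sum()))  # True
```

Note that m_e depends on three variables (a, c, d), so the E step touches a table over four variables; this step dominates the cost.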
Graph transformation
Let’s visualize how variable elimination works. First, we moralize the BN into an MRF.
Then, we remove the corresponding node for each eliminated variable.
When we remove a node, we connect all of its remaining neighbors to each other, because the new factor m depends on all of them jointly and the graph must model that dependency. For example, in the top right below, before F is removed, we add a new edge (a fill edge) between A and E.
Each elimination step can be visualized as removing one node from a clique (linked with red lines) and replacing the related factors in the joint distribution with the new factor m.
Computing P(A) can then be visualized as message passing between cliques in a directed tree. For example, the message m_e is computed by multiplying P(E|C, D) with its sub-tree messages, m_g and m_f, and summing out E.
The complexity of variable elimination is determined by the largest clique formed during elimination: it grows exponentially with the number of variables in that clique.
For a clique with 3 nodes each having k unique values, the complexity is O(k³).
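To see fill edges and clique sizes concretely, here is a sketch of elimination on an undirected graph. Since the figure is not reproduced, the edge list assumes the moralized form of the factorization P(a)P(b)P(c|b)P(d|a)P(e|c,d)P(f|a)P(g|e)P(h|e,f) (moral edges C-D and E-F); `eliminate` is our own helper. With the order H, G, F, E, D, C, B (then A), it adds the A-E fill edge when F is eliminated, matching the description above, and the largest clique it creates has four nodes, so the cost is O(k⁴) for k-valued variables.

```python
from itertools import combinations

def eliminate(edges: list[tuple[str, str]], order: list[str]):
    """Eliminate nodes in order; return the fill edges and the cliques formed."""
    graph: dict[str, set[str]] = {}
    for u, v in edges:
        graph.setdefault(u, set()).add(v)
        graph.setdefault(v, set()).add(u)
    fill_edges, cliques = [], []
    for node in order:
        nbrs = graph.pop(node)
        cliques.append({node} | nbrs)          # node plus its current neighbors
        for a, b in combinations(sorted(nbrs), 2):
            if b not in graph[a]:               # connect the remaining neighbors
                graph[a].add(b)
                graph[b].add(a)
                fill_edges.append((a, b))
        for u in nbrs:
            graph[u].discard(node)
    return fill_edges, cliques

# Moralized version of the assumed example BN (edges C-D and E-F are moral edges).
moral_edges = [("A", "D"), ("A", "F"), ("B", "C"), ("C", "E"), ("D", "E"),
               ("C", "D"), ("E", "G"), ("E", "H"), ("F", "H"), ("E", "F")]
fills, cliques = eliminate(moral_edges, ["H", "G", "F", "E", "D", "C", "B", "A"])
print(fills)                          # [('A', 'E'), ('A', 'C'), ('A', 'B')]
print(max(len(c) for c in cliques))   # 4
```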
So far, our solution answers only one query, P(A). For other queries, like P(D), do we need to recompute everything? Fortunately, variable elimination is simply message passing inside a directed clique tree, so the intermediate messages m can be reused by other queries. This leads us to the next topic: belief propagation.
Message-passing (belief propagation)
Let’s simplify our discussion to a tree of nodes instead of a general graph.
To compute P(a), we make A the root of the tree and pass messages from the leaves up to the root. The message from node i to its parent j multiplies 𝜙(xᵢ), 𝜙(xᵢ, xⱼ), and all the messages from the children of i, then sums over xᵢ:
$$m_{ij}(x_j) = \sum_{x_i} \phi(x_i)\, \phi(x_i, x_j) \prod_{k \in \mathrm{children}(i)} m_{ki}(x_i)$$
This message can be read as i's belief about the state of j. So once a node has received all the messages from its children, it can propagate its own message to its parent.
To answer queries on other nodes, we also compute all messages in the reverse direction. For example, if we have computed m₁₂, we also compute m₂₁. So for a tree with n edges, we compute 2n messages in total.
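A minimal sketch of the leaf-to-root pass, assuming a hypothetical four-node tree with random positive potentials (node 0 is the root A). `message_to_parent` implements the message equation above recursively, and the root marginal is checked against brute-force summation over the joint table.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
k = 2  # values per variable
# Hypothetical tree rooted at node 0: node 0 has children 1 and 2; node 1 has child 3.
children = {0: [1, 2], 1: [3], 2: [], 3: []}
unary = {i: rng.random(k) + 0.1 for i in children}          # phi(x_i)
pairwise = {(p, c): rng.random((k, k)) + 0.1                 # phi(x_p, x_c), indexed [x_p, x_c]
            for p, cs in children.items() for c in cs}

def message_to_parent(i: int, parent: int) -> np.ndarray:
    """m_{i->parent}(x_parent) = sum_{x_i} phi(x_i) phi(x_parent, x_i) * child messages."""
    prod = unary[i].copy()
    for c in children[i]:
        prod *= message_to_parent(c, i)     # messages from i's children, computed first
    return pairwise[(parent, i)] @ prod     # matrix-vector product sums over x_i

# The root combines its own factor with the messages from its children.
belief = unary[0].copy()
for c in children[0]:
    belief *= message_to_parent(c, 0)
p_root = belief / belief.sum()

# Brute-force check: build the full joint table and sum out nodes 1, 2, 3.
joint = np.zeros((k,) * 4)
for x in itertools.product(range(k), repeat=4):
    val = np.prod([unary[i][x[i]] for i in children])
    for (p, c), table in pairwise.items():
        val *= table[x[p], x[c]]
    joint[x] = val
p_brute = joint.sum(axis=(1, 2, 3)) / joint.sum()
print(np.allclose(p_root, p_brute))  # True
```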
Once the node i below receives all the messages from its neighbors except j, it can propagate its message to j.
Once this is done, the marginal of any node is proportional to its own factor 𝜙(xᵢ) times the product of all the messages coming into that node.
To get the actual normalized probability, we divide by the partition function Z, which is computed by summing this unnormalized quantity over xᵢ.
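Here is a sketch of the full two-direction schedule on a hypothetical five-node tree. `message(i, j)` is memoized with `functools.lru_cache`, so each directed message is computed once, only after the messages from i's other neighbors exist. Computing every node's marginal requests all 2n messages; the results are checked against the normalized joint table.

```python
import itertools
from functools import lru_cache
import numpy as np

rng = np.random.default_rng(1)
k, n_nodes = 3, 5
edges = [(0, 1), (1, 2), (1, 3), (3, 4)]   # hypothetical tree with n = 4 edges
nbrs = {i: set() for i in range(n_nodes)}
for a, b in edges:
    nbrs[a].add(b)
    nbrs[b].add(a)
unary = rng.random((n_nodes, k)) + 0.1        # phi(x_i)
pair = {}
for a, b in edges:
    t = rng.random((k, k)) + 0.1              # phi(x_a, x_b), indexed [x_a, x_b]
    pair[(a, b)], pair[(b, a)] = t, t.T

@lru_cache(maxsize=None)
def message(i: int, j: int) -> np.ndarray:
    """m_{i->j}(x_j), computed once i has the messages from all neighbors except j."""
    prod = unary[i].copy()
    for nb in nbrs[i] - {j}:
        prod *= message(nb, i)
    return pair[(j, i)] @ prod                 # sum over x_i

def marginal(i: int) -> np.ndarray:
    """phi(x_i) times all incoming messages, divided by Z."""
    belief = unary[i].copy()
    for nb in nbrs[i]:
        belief *= message(nb, i)
    return belief / belief.sum()

marginals = [marginal(i) for i in range(n_nodes)]
print(message.cache_info().currsize)           # 8, i.e. 2n messages for n = 4 edges

# Brute-force check against the full joint table.
joint = np.ones((k,) * n_nodes)
for x in itertools.product(range(k), repeat=n_nodes):
    for i in range(n_nodes):
        joint[x] *= unary[i, x[i]]
    for a, b in edges:
        joint[x] *= pair[(a, b)][x[a], x[b]]
joint /= joint.sum()
exact = [joint.sum(axis=tuple(a for a in range(n_nodes) if a != i)) for i in range(n_nodes)]
```

Memoization is one convenient way to respect the "wait for all other neighbors" rule; an explicit leaf-to-root then root-to-leaf schedule would compute the same 2n messages.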
Junction tree
Previously, we solved the inference problem assuming the graph is a tree. Now, we extend it to general graphs with the junction tree method. The general concept is divide-and-conquer. The graph may not be a tree, but we may partition it into a tree-like structure of clusters holding variables. For example, the graph G below can be partitioned into a tree containing four clusters, circled on the right below.
Variables within a cluster can be highly coupled. The hope is that each cluster is small enough for the computation within it to be manageable and exact. Interactions between clusters follow a tree structure, so once each cluster is handled, we can solve the problem globally by message passing, as in the belief propagation algorithm.
To qualify as a junction tree, a tree of clusters must satisfy:
- There is exactly one path between each pair of clusters,
- Every clique in G is contained in some cluster, and
- If two clusters contain node i, every cluster on the path between them must also contain node i (the running intersection property).
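The third condition can be checked with a short sketch. `has_running_intersection` is our own helper; it assumes the edges already form a tree and does not check the clique condition. In a tree, "every cluster on the path between two holders of v also holds v" is equivalent to "the clusters holding v form a connected subtree", which a flood fill can test.

```python
def has_running_intersection(clusters: dict[str, set[str]],
                             tree_edges: list[tuple[str, str]]) -> bool:
    """Check the running intersection property on a tree of clusters."""
    adj: dict[str, set[str]] = {c: set() for c in clusters}
    for a, b in tree_edges:
        adj[a].add(b)
        adj[b].add(a)
    for v in set().union(*clusters.values()):
        holders = {c for c, vs in clusters.items() if v in vs}
        start = min(holders)
        seen, stack = {start}, [start]
        while stack:                      # flood fill through clusters that contain v
            for nb in adj[stack.pop()]:
                if nb in holders and nb not in seen:
                    seen.add(nb)
                    stack.append(nb)
        if seen != holders:
            return False
    return True

clusters = {"C1": {"a", "b", "c"}, "C2": {"b", "c", "d"}, "C3": {"c", "d", "e"}}
print(has_running_intersection(clusters, [("C1", "C2"), ("C2", "C3")]))  # True
print(has_running_intersection(clusters, [("C1", "C3"), ("C3", "C2")]))  # False: C3 lacks b
```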
Building a Junction tree
Before building a junction tree, if the graph is directed, we have to moralize it first. Then, the graph needs to be triangulated so that it becomes chordal.
A graph is chordal if every cycle of four or more nodes has a chord, i.e. an edge connecting two nodes of the cycle that are not adjacent along it. For example, on the top right above, the thick line forms a cycle without a chord, and we need to add two more links to make the graph chordal. We state without proof that every triangulated (chordal) graph has a junction tree.
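A compact way to test chordality, sketched below, uses the fact that a graph is chordal if and only if we can repeatedly remove a simplicial vertex (one whose neighbors are all connected to each other) until the graph is empty. `is_chordal` and `from_edges` are our own helper names, and this quadratic-time loop favors clarity over speed.

```python
from itertools import combinations

def is_chordal(graph: dict[str, set[str]]) -> bool:
    """Chordal iff simplicial vertices can be removed one by one until nothing is left."""
    g = {v: set(n) for v, n in graph.items()}  # work on a copy
    while g:
        simplicial = next((v for v in g
                           if all(b in g[a] for a, b in combinations(g[v], 2))), None)
        if simplicial is None:
            return False                       # a chordless cycle of length >= 4 remains
        for u in g.pop(simplicial):
            g[u].discard(simplicial)
    return True

def from_edges(edges: list[tuple[str, str]]) -> dict[str, set[str]]:
    g: dict[str, set[str]] = {}
    for u, v in edges:
        g.setdefault(u, set()).add(v)
        g.setdefault(v, set()).add(u)
    return g

square = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "a")]
print(is_chordal(from_edges(square)))                 # False
print(is_chordal(from_edges(square + [("a", "c")])))  # True: a-c is a chord
```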
Starting with a variable elimination order, the junction tree is built as follows:
After step 1, we keep only the maximal cliques; for example, the clique (a, b) is thrown away because it is a subset of the larger clique (a, b, c). In step 2, we build the junction tree from these cliques using the maximum-weight spanning tree.
The weight of the edge between two clusters is the number of variables they share. For example, the purple and blue clusters have two common variables, so the weight is two. Finally, we compute the maximum-weight spanning tree (the thicker lines) and use it as the junction tree.
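A sketch of step 2 using Kruskal's algorithm (one standard way to get a maximum-weight spanning tree) with a small union-find. The list of maximal cliques is hypothetical, and ties between equal-weight edges are broken by index, so other equally valid junction trees exist.

```python
from itertools import combinations

def junction_tree_edges(cliques: list[frozenset[str]]) -> list[tuple[int, int, int]]:
    """Maximum-weight spanning tree (Kruskal) where weight = |C_i & C_j|."""
    candidates = sorted(
        ((len(cliques[i] & cliques[j]), i, j)
         for i, j in combinations(range(len(cliques)), 2)
         if cliques[i] & cliques[j]),
        reverse=True,                       # heaviest edges first
    )
    parent = list(range(len(cliques)))      # union-find forest

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]   # path halving
            x = parent[x]
        return x

    tree = []
    for w, i, j in candidates:
        ri, rj = find(i), find(j)
        if ri != rj:                        # skip edges that would create a cycle
            parent[ri] = rj
            tree.append((i, j, w))
    return tree

# Hypothetical maximal cliques of a triangulated graph.
cliques = [frozenset("EFH"), frozenset("EG"), frozenset("AEF"),
           frozenset("ACDE"), frozenset("ABC")]
tree = junction_tree_edges(cliques)
print(tree)  # [(3, 4, 2), (2, 3, 2), (0, 2, 2), (1, 3, 1)]
```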
Here is an example of how we marginalize x₆ and create a new message m(x₂, x₃, x₅).
The message from cluster i to cluster j is calculated by multiplying the potential of cᵢ with the messages from its other neighboring clusters, then marginalizing out the variables in cᵢ that are not in cⱼ.
The belief of cluster c, based on the messages it receives, is
Therefore, the marginal probability of x can be computed by marginalizing out the other variables of a cluster that contains x.
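A minimal sketch on a hypothetical two-cluster junction tree, C1 = {a, b, c} and C2 = {b, c, d} with separator {b, c}, using made-up potentials. Each cluster has only one neighbor, so each message is just the sender's potential with the non-shared variables summed out. The beliefs give the marginals of a and d, which match brute force, and the two beliefs agree on the separator.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical junction tree with two clusters C1 = {a, b, c} and C2 = {b, c, d},
# separated by S = {b, c}. psi1 and psi2 are made-up cluster potentials.
psi1 = rng.random((2, 3, 3))        # psi1[a, b, c]
psi2 = rng.random((3, 3, 2))        # psi2[b, c, d]

# Each message sums out the sender's variables that the receiver does not have.
# (Each cluster has only one neighbor here, so no other messages are multiplied in.)
m_12 = psi1.sum(axis=0)             # m_{1->2}(b, c) = sum_a psi1(a, b, c)
m_21 = psi2.sum(axis=2)             # m_{2->1}(b, c) = sum_d psi2(b, c, d)

# Cluster beliefs: own potential times all incoming messages.
belief1 = psi1 * m_21[None, :, :]   # beta1(a, b, c)
belief2 = psi2 * m_12[:, :, None]   # beta2(b, c, d)

# Marginals: sum out the other variables of a cluster containing the target.
Z = belief1.sum()
p_a = belief1.sum(axis=(1, 2)) / Z
p_d = belief2.sum(axis=(0, 1)) / Z

# Brute-force check on the full joint psi1 * psi2.
joint = np.einsum("abc,bcd->abcd", psi1, psi2)
print(np.allclose(p_a, joint.sum(axis=(1, 2, 3)) / joint.sum()),
      np.allclose(p_d, joint.sum(axis=(0, 1, 2)) / joint.sum()))  # True True
```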
Loopy belief propagation
The complexity of inference with the junction tree grows exponentially with the size of the largest cluster. For some graphical models, this cluster can be huge, in particular when the graph is loopy (contains many non-local cycles). For an Ising model on an N × N grid, the width of the junction tree is N, so the corresponding clique table has on the order of 2ᴺ entries.
But sometimes, we may settle for a good-enough solution rather than an exact one. In belief propagation, node i propagates a message to node j only after it receives the messages from all of its neighbors other than j. In loopy belief propagation, all messages are typically initialized uniformly. We then take an ordered list of all (directed) edges and go through it one by one; for each edge i → j, we propagate the message from i to j using the following equation.
This is the same equation as in belief propagation, but belief propagation waits for the neighboring messages to be finalized before sending its own message. Loopy belief propagation does not wait: it uses the current message values (starting from the uniform initialization). We keep iterating over the list of edges to update the messages. Since messages may circulate around the loops indefinitely, we iterate for a fixed number of steps or until the messages converge. In practice, it often works surprisingly well.
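A sketch of this schedule for a hypothetical pairwise MRF over binary variables with random unary factors and a weak "agreement" coupling. Messages start uniform, each sweep goes through a fixed sorted list of directed edges, each update uses the current values of the other messages, and messages are normalized for numerical stability. The fixed sweep count and the model are assumptions for illustration.

```python
import numpy as np
from itertools import product

def pairwise_model(n, edges, coupling=1.5, seed=0):
    """Hypothetical binary MRF: random unary factors, coupling favors agreement."""
    rng = np.random.default_rng(seed)
    unary = [rng.random(2) + 0.5 for _ in range(n)]
    t = np.array([[coupling, 1.0], [1.0, coupling]])
    pair = {}
    for i, j in edges:
        pair[(i, j)], pair[(j, i)] = t, t.T   # pair[(i, j)] is indexed [x_i, x_j]
    return unary, pair

def loopy_bp(n, edges, unary, pair, sweeps=50):
    nbrs = {i: set() for i in range(n)}
    for i, j in edges:
        nbrs[i].add(j)
        nbrs[j].add(i)
    msgs = {(i, j): np.full(2, 0.5) for i in range(n) for j in nbrs[i]}  # uniform init
    schedule = sorted(msgs)                   # fixed ordered list of directed edges
    for _ in range(sweeps):
        for i, j in schedule:
            prod = unary[i].copy()
            for k in nbrs[i] - {j}:
                prod *= msgs[(k, i)]          # use current values; do not wait
            m = pair[(j, i)] @ prod           # sum over x_i
            msgs[(i, j)] = m / m.sum()        # normalize to avoid under/overflow
    beliefs = []
    for i in range(n):
        b = unary[i].copy()
        for k in nbrs[i]:
            b *= msgs[(k, i)]
        beliefs.append(b / b.sum())
    return np.array(beliefs)

def exact_marginals(n, edges, unary, pair):
    joint = np.zeros((2,) * n)
    for x in product((0, 1), repeat=n):
        v = np.prod([unary[i][x[i]] for i in range(n)])
        for i, j in edges:
            v *= pair[(i, j)][x[i], x[j]]
        joint[x] = v
    joint /= joint.sum()
    return np.array([joint.sum(axis=tuple(a for a in range(n) if a != i))
                     for i in range(n)])

cycle = [(0, 1), (1, 2), (2, 3), (3, 0)]
unary, pair = pairwise_model(4, cycle)
approx = loopy_bp(4, cycle, unary, pair)
exact = exact_marginals(4, cycle, unary, pair)
print(np.abs(approx - exact).max())   # small, but in general not zero on a loop
```

On a tree the same code reproduces the exact marginals; on the loop the result is only an approximation, and with strong couplings it can be noticeably off or fail to converge.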
Maximum a posteriori — MAP inference
In the last few sections, we focused on marginal inference. Now, we turn to the second category: MAP inference.
Max-product inference (Viterbi algorithm)
It turns out we can use the concept of belief propagation to solve MAP inference. Instead of summing over the factors as in marginal inference, we take the maximum over them. That is, instead of sum-product inference, MAP inference is max-product inference. The message between two nodes will be
And the optimal value xᵢ* will be
So the mechanism is almost the same as sum-product inference, except that we replace the summation in the marginalization with the max function. The potential 𝜙 can also be expressed in an energy model with parameters θ, e.g. 𝜙 = exp(θ). Taking the logarithm turns the products into sums, so the corresponding MAP inference becomes max-sum inference instead of max-product inference.
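A sketch of max-sum on a hypothetical chain in log space (random θ values, assuming 𝜙 = exp θ). The forward pass is the sum-product recursion with the sum replaced by a max; storing the argmax at each step lets us backtrack the optimal assignment x*, as in the Viterbi algorithm. The result is checked against brute-force search.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
n, k = 5, 3
# Hypothetical chain model in log space: theta_unary[i, x_i], theta_pair[i, x_i, x_{i+1}].
theta_unary = rng.normal(size=(n, k))
theta_pair = rng.normal(size=(n - 1, k, k))

def max_sum(theta_unary: np.ndarray, theta_pair: np.ndarray) -> tuple[list[int], float]:
    """Max-sum (Viterbi) on a chain: forward max-messages, then backtrack."""
    n, k = theta_unary.shape
    score = theta_unary[0].copy()                  # best log score of a prefix ending in x_0
    back = np.zeros((n, k), dtype=int)             # back[i][x_i] = best x_{i-1}
    for i in range(1, n):
        cand = score[:, None] + theta_pair[i - 1]  # indexed [x_{i-1}, x_i]
        back[i] = cand.argmax(axis=0)
        score = cand.max(axis=0) + theta_unary[i]  # max replaces the sum of sum-product
    x = [int(score.argmax())]                      # x*_{n-1}
    for i in range(n - 1, 0, -1):
        x.append(int(back[i][x[-1]]))
    return x[::-1], float(score.max())

def total_score(x) -> float:
    return (sum(theta_unary[i, xi] for i, xi in enumerate(x))
            + sum(theta_pair[i, x[i], x[i + 1]] for i in range(n - 1)))

best_x, best_score = max_sum(theta_unary, theta_pair)
brute_x = max(itertools.product(range(k), repeat=n), key=total_score)
print(best_x == list(brute_x))  # True
```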
Exact solutions & Approximation methods
If the model is not too complex, we can find an exact solution with the methods illustrated in this article:
- Variable elimination algorithm
- Message-passing algorithm (belief propagation — sum-product inference for marginal distribution or max-product inference for MAP)
- The junction tree algorithms
But exact solutions can be intractable. In that case, we fall back on approximation methods, which include:
- Loopy belief propagation
- Sampling methods
- Variational inference
For the next couple of articles, we will explore the approximation methods.
Credits and references
Credits and references are listed in the first article of the Graphical model series.