Machine Learning — Graphical Model Exact inference (Variable elimination, Belief propagation, Junction tree)


We build inference systems to emulate human intelligence. With a probabilistic model in Machine Learning (ML), we describe a problem as a joint probability distribution over the observable and the hidden variables.

Image for post

We can generalize this model as a joint probability of variables:

Image for post

Typical inference queries fall into three categories:

Image for post

If e is the set of observed variables, p(e) is the likelihood of the observed data. We can think of the joint distribution p(x) as being made up of the groups of variables below:

Image for post

All the familiar inference tasks below fall into these three categories.

Image for post

We can solve all these inferences through marginal inference and/or MAP (maximum a posteriori) inference.

Image for post

As shown, everything starts with the joint probability mentioned above. Any query over any subset of variables can be answered from it. It becomes obvious now why:

Probabilistic methods model the ML problem as a joint probability distribution.
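As a rough illustration of answering queries straight from the joint distribution, here is a minimal brute-force sketch. The three binary variables and the potential tables `phi_ab` and `phi_bc` are made-up assumptions, not taken from the figures in this article.

```python
from itertools import product

# Hypothetical unnormalized model over three binary variables (a, b, c):
# score(a, b, c) = phi_ab[a][b] * phi_bc[b][c]. The numbers are made up.
phi_ab = [[4.0, 1.0], [1.0, 3.0]]
phi_bc = [[2.0, 1.0], [1.0, 5.0]]

def score(a, b, c):
    return phi_ab[a][b] * phi_bc[b][c]

states = list(product([0, 1], repeat=3))  # k**N = 2**3 joint assignments

# Partition function: a sum over every joint assignment.
Z = sum(score(*s) for s in states)

# Marginal inference: p(a) sums out b and c.
p_a = [sum(score(a, b, c) for b in (0, 1) for c in (0, 1)) / Z for a in (0, 1)]

# MAP inference: the single most probable joint assignment (Z is not needed).
map_state = max(states, key=lambda s: score(*s))
```

Both queries loop over all k**N assignments, which is fine for three variables but grows exponentially as variables are added.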

However, marginal inference and MAP inference are NP-hard in general. Unless we discover a simple structure underneath the data, the work needed to answer a query grows exponentially with the number of variables. Specifically, in marginal inference, we sum over an exponential number of terms to calculate the partition function Z. Even though MAP inference does not involve Z, we are still likely to enumerate an exponential number of possibilities to find the maximum. In our previous article on the Graphical model, we exploited independence among variables to reduce this complexity.

Image for post

If we discover or make enough bold independence assumptions, like those in the Naive Bayes Algorithm or the Hidden Markov Model (HMM), we can make the model tractable.

Image for post
HMM

For example, the Naive Bayes model is reduced to a class prior and a product of single-variable conditional distributions, which makes MAP inference tractable.

Image for post
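To make the Naive Bayes case concrete, the sketch below finds the MAP class by scoring each class once. The class names, the prior, and the per-feature probabilities are hypothetical values chosen only for illustration.

```python
import math

# Hypothetical parameters for a two-class model with three binary features.
prior = {"spam": 0.4, "ham": 0.6}
# p_feature[y][i] = p(x_i = 1 | y); made-up numbers for illustration.
p_feature = {"spam": [0.8, 0.6, 0.1], "ham": [0.2, 0.3, 0.5]}

def log_joint(y, x):
    """log p(y) + sum_i log p(x_i | y) under the Naive Bayes factorization."""
    total = math.log(prior[y])
    for p1, xi in zip(p_feature[y], x):
        total += math.log(p1 if xi == 1 else 1.0 - p1)
    return total

def map_class(x):
    # One pass over the classes; the features never have to be enumerated.
    return max(prior, key=lambda y: log_joint(y, x))

label = map_class([1, 1, 0])
```

Because the features are assumed independent given the class, the cost is linear in the number of features rather than exponential.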

In general, after introducing the Graphical model, we replace the joint distributions with factors in the graph.

Image for post
Replacing the joint probability with Bayesian network factors

All these discussions come to an important conclusion.

In general, inference is NP-hard and difficult to solve.

But if the model’s factors are simple enough, we can find an exact solution efficiently. If they are not, it remains possible to find a good solution via approximation methods. You will realize soon enough that a huge amount of ML effort is devoted to addressing these problems. For the remainder of this article, we will discuss how exact inference can be done provided the model is simple enough. We will address the approximation solutions later in multiple articles.

However, in many ML problems, we are often interested in a very narrow scope of queries in the form of a conditional distribution. For example, in a classification problem, we are interested in P(y|x) but not other possible queries. To address this type of problem, we model the query directly using a Conditional Random Field (CRF) like the one below.

Image for post

Our conditional probability is factored just like other Graphical models:

Image for post

Then, we can solve the problem with the Graphical model methods. The following is the MAP inference for the CRF above.

Image for post

Variable elimination

Variable elimination solves our inference query exactly.

Image for post

The concept of variable elimination is to eliminate one variable at a time from the marginal distribution expression.

Image for post

For simplicity, we assume each variable has k unique values. The most naive way to find p(l) is to sum p(l, *, *, *, *) over all kᴺ⁻¹ possible assignments of the other N-1 variables. But to reduce the number of operations, we can rearrange the order of the summations and push each one as close as possible to the factors that depend on its variable.

Image for post

So, starting from the right end, we can marginalize one variable at a time. In each step below, we create a new factor τ to represent the marginalized factors. τ(a, b, c) represents an intermediate factor that depends on a, b, and c only. If we select the order of elimination smartly, we can reduce the computational complexity significantly. For example, the complexity of the joint probability below is reduced to O() instead of O(k⁴) in the brute force approach.

Image for post
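The effect of pushing summations inward can be sketched on a small example. The chain a → b → c → d below, its tables, and the helper names are assumptions made for illustration and are not the network in the figure.

```python
from itertools import product

k = 3  # each variable takes k values

def normalize(row):
    s = sum(row)
    return [v / s for v in row]

# Hypothetical chain a -> b -> c -> d with made-up tables.
p_a = normalize([1.0, 2.0, 3.0])
# cpt[x][y] = p(next = y | prev = x); reused for every link for brevity.
cpt = [normalize([1.0 + x, 2.0, 3.0 - x]) for x in range(k)]

def eliminate(tau, cpt):
    """tau_new(y) = sum_x tau(x) * p(y | x): one O(k^2) elimination step."""
    return [sum(tau[x] * cpt[x][y] for x in range(k)) for y in range(k)]

tau_b = eliminate(p_a, cpt)    # sums out a, leaving a factor over b
tau_c = eliminate(tau_b, cpt)  # sums out b, leaving a factor over c
p_d = eliminate(tau_c, cpt)    # sums out c, leaving p(d)

# Brute force for comparison: k^3 terms for each of the k values of d.
brute = [sum(p_a[a] * cpt[a][b] * cpt[b][c] * cpt[c][d]
             for a, b, c in product(range(k), repeat=3))
         for d in range(k)]
```

For this chain, three elimination steps of O(k²) each replace a brute-force sum of O(k⁴) terms, and both routes give the same p(d).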

Let us generalize the concept to include MRFs. For simplicity, we will ignore the partition function Z in our discussion since it can be calculated or eliminated with further variable elimination.

Image for post

Consider the joint probability:

Image for post

The variable elimination algorithm performs the following steps using an ordered list of variables to be eliminated.

  1. Multiply all factors Φ containing a variable, say b
  2. Marginalize b to obtain a new factor τ
  3. Replace these factors Φ with the new intermediate factor τ
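The three steps above can be sketched as follows. The factor representation (a tuple of variable names plus a table keyed by assignments), the function names, and the two example factors are assumptions for illustration; every variable is taken to be binary.

```python
from itertools import product

K = 2  # every variable is binary in this sketch

def multiply(f, g):
    """Product of two factors; a factor is (variables, {assignment: value})."""
    fv, ft = f
    gv, gt = g
    out_vars = fv + tuple(v for v in gv if v not in fv)
    table = {}
    for assign in product(range(K), repeat=len(out_vars)):
        env = dict(zip(out_vars, assign))
        table[assign] = (ft[tuple(env[v] for v in fv)]
                         * gt[tuple(env[v] for v in gv)])
    return out_vars, table

def sum_out(f, var):
    """Marginalize var out of factor f."""
    fv, ft = f
    idx = fv.index(var)
    table = {}
    for assign, value in ft.items():
        key = assign[:idx] + assign[idx + 1:]
        table[key] = table.get(key, 0.0) + value
    return fv[:idx] + fv[idx + 1:], table

def eliminate(factors, order):
    for var in order:
        touching = [f for f in factors if var in f[0]]
        if not touching:
            continue
        rest = [f for f in factors if var not in f[0]]
        merged = touching[0]
        for f in touching[1:]:
            merged = multiply(merged, f)   # step 1: multiply factors with var
        tau = sum_out(merged, var)         # step 2: marginalize var
        factors = rest + [tau]             # step 3: replace them with tau
    result = factors[0]
    for f in factors[1:]:
        result = multiply(result, f)
    return result

# Hypothetical pairwise factors over (a, b) and (b, c).
phi_ab = (("a", "b"), {(0, 0): 4.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 3.0})
phi_bc = (("b", "c"), {(0, 0): 2.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 5.0})
remaining_vars, table = eliminate([phi_ab, phi_bc], order=["c", "b"])
```

The result is an unnormalized factor over a alone; dividing by the sum of its entries gives p(a).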

Example:

Image for post
Sum-product inference

Since we multiply the factors and then sum over a variable, this is also called sum-product inference. Here is an example of how to marginalize a variable.

Image for post
Source

Finding the optimal order of variable elimination is an NP-hard problem. But for some graphs, the optimal or near-optimal order can be obvious. In general, we can follow the guidelines below:

  • Choose the variable with the fewest dependent variables (min-neighbors), i.e. the vertex with the fewest neighbors.
  • Choose the variable whose elimination produces the smallest merged factor.
Image for post
  • Minimize the number of fill edges (discussed later) that need to be added to the graph when eliminating a variable.
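One way to turn the first guideline into code is a greedy loop that always removes the vertex with the fewest neighbors and records the fill edges it creates. The function name, the adjacency-dictionary format, and the two example graphs are assumptions for illustration; this is a heuristic and is not guaranteed to find the optimal order.

```python
def min_neighbors_order(adjacency):
    """Greedy elimination order: always remove the vertex with fewest neighbors.

    adjacency maps each vertex to the set of its neighbors (undirected graph).
    Returns the order and the fill edges added along the way.
    """
    graph = {v: set(ns) for v, ns in adjacency.items()}  # work on a copy
    order, fill_edges = [], []
    while graph:
        # Ties are broken alphabetically to keep the result deterministic.
        v = min(sorted(graph), key=lambda u: len(graph[u]))
        nbrs = sorted(graph[v])
        # Connect the neighbors of v to each other before removing v.
        for i, a in enumerate(nbrs):
            for b in nbrs[i + 1:]:
                if b not in graph[a]:
                    graph[a].add(b)
                    graph[b].add(a)
                    fill_edges.append((a, b))
        for n in nbrs:
            graph[n].discard(v)
        del graph[v]
        order.append(v)
    return order, fill_edges

# Hypothetical 4-cycle a - b - c - d - a and a star centered on c.
cycle = {"a": {"b", "d"}, "b": {"a", "c"}, "c": {"b", "d"}, "d": {"a", "c"}}
star = {"c": {"x", "y", "z"}, "x": {"c"}, "y": {"c"}, "z": {"c"}}
cycle_order, cycle_fill = min_neighbors_order(cycle)
star_order, star_fill = min_neighbors_order(star)
```

On the star, the leaves are removed first and no fill edge is needed; on the 4-cycle, one fill edge is unavoidable.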

Example

(Credit: example adapted from here.)

Image for post

To compute P(A|H=h), starting with the elimination list {H, G, F, E, D, C, B}, we eliminate one variable from the BN according to the order in the list.

Image for post

Graph transformation

Let’s visualize how variable elimination is done. First, we moralize the BN into an MRF.

Image for post
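Moralization itself is mechanical: drop the edge directions and connect every pair of parents that share a child. A minimal sketch, assuming the BN is given as a hypothetical dictionary from each node to its list of parents:

```python
def moralize(parents):
    """Return the undirected edge set of the moral graph of a Bayesian network.

    parents maps each node to the list of its parents.
    """
    edges = set()
    for child, pars in parents.items():
        for p in pars:
            edges.add(frozenset((p, child)))   # drop the edge direction
        for i, p in enumerate(pars):
            for q in pars[i + 1:]:
                edges.add(frozenset((p, q)))   # "marry" the co-parents
    return edges

# Hypothetical v-structure a -> c <- b.
moral = moralize({"a": [], "b": [], "c": ["a", "b"]})
```

For the v-structure, the moral graph gains the edge between a and b, so the factor P(c|a, b) fits inside a single clique.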

Then, we remove the corresponding node for each eliminated variable.

Image for post

When we remove a node, we make sure its previous neighbors are connected so that the joint probability is modeled correctly with the new factor m. Therefore, in the top right below, before F is removed, we add a new edge (a fill edge) between A and E.

Image for post

In each step of the elimination, we can visualize it as eliminating one node from a clique (linked with red lines) and replacing the related factors in the joint distribution with the new factor m.

Image for post

The calculation of P(A) can be visualized as message passing between cliques inside a directed tree. For example, the message m_e is computed by marginalizing the product of P(E|C, D) and its sub-tree messages, m_g and m_f.

Image for post

The complexity of variable elimination grows exponentially with the number of variables in the largest clique formed during the elimination.

For a clique with 3 nodes each having k unique values, the complexity is O(k³).
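To see how the elimination order drives this cost, the sketch below replays an order on an undirected graph and reports the largest clique it creates. The function name and the star-shaped example graph are assumptions for illustration.

```python
def largest_elimination_clique(adjacency, order):
    """Size of the largest clique (a variable plus its neighbors) formed
    while eliminating the vertices of an undirected graph in the given order."""
    graph = {v: set(ns) for v, ns in adjacency.items()}  # work on a copy
    largest = 0
    for v in order:
        nbrs = graph[v]
        largest = max(largest, len(nbrs) + 1)  # v plus its current neighbors
        for a in nbrs:
            graph[a] |= nbrs - {a}             # add fill edges among neighbors
            graph[a].discard(v)
        del graph[v]
    return largest

# Hypothetical star graph: center c connected to leaves x, y and z.
star = {"c": {"x", "y", "z"}, "x": {"c"}, "y": {"c"}, "z": {"c"}}
k = 2
leaves_first = largest_elimination_clique(star, ["x", "y", "z", "c"])
center_first = largest_elimination_clique(star, ["c", "x", "y", "z"])
cost_leaves_first = k ** leaves_first  # largest intermediate table
cost_center_first = k ** center_first
```

Eliminating the leaves first never builds a table over more than two variables, while eliminating the center first builds one over all four.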

So far, our solution addresses one query only, P(A). For other queries, like P(D), do we need to recompute everything again? Fortunately, as we may have already realized, variable elimination is simply message passing inside a directed clique tree. The intermediate messages m are reusable in other queries, which leads us to the next topic: belief propagation.

Message-passing (belief propagation)

Let’s simplify our discussion to a tree of nodes, instead of a graph.

Image for post

To compute P(a), we make A the root of the tree and pass messages from the leaves up to the root. The message from node i to node j above is Σ 𝜙(xᵢ) × 𝜙(xᵢ, xⱼ) × all messages from the children of i, where the sum runs over the values of xᵢ. This message is called the belief of j from i. So when a node has received all the messages from its children, it can propagate its own message to its parent.

To answer queries on other nodes, we also compute all messages in the reverse direction. For example, if we have computed m₁₂, we will also compute m₂₁. So for a tree with n edges, we need to compute 2n messages.

Image for post

Once the node i below receives all the messages from its neighbors except j, it can propagate its message to j.

Image for post

Once this is done, the marginal for any node is proportional to the product of its own factor 𝜙(xᵢ) and all the messages coming into that node.

Image for post

To get the actual normalized probability, we divide it by the partition function Z, which can be computed by summing this product over xᵢ.
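The message-passing scheme above can be sketched on a small tree. The tree shape, the potentials, and the function names are assumptions for illustration; messages are cached so that each of the 2n directed messages is computed only once, and a brute-force routine is included as a check.

```python
from functools import lru_cache
from itertools import product

K = 2
edges = [(0, 1), (1, 2), (1, 3)]  # hypothetical 4-node tree
phi = [[1.0, 2.0], [1.0, 1.0], [3.0, 1.0], [1.0, 4.0]]  # node potentials
psi = [[2.0, 1.0], [1.0, 2.0]]    # shared symmetric edge potential

neighbors = {i: [] for i in range(len(phi))}
for i, j in edges:
    neighbors[i].append(j)
    neighbors[j].append(i)

@lru_cache(maxsize=None)
def message(i, j):
    """m_{i->j}(x_j): sum over x_i of phi_i * psi * messages into i, except from j."""
    out = []
    for xj in range(K):
        total = 0.0
        for xi in range(K):
            term = phi[i][xi] * psi[xi][xj]
            for n in neighbors[i]:
                if n != j:
                    term *= message(n, i)[xi]
            total += term
        out.append(total)
    return tuple(out)

def marginal(i):
    """Node factor times all incoming messages, then normalized."""
    belief = list(phi[i])
    for n in neighbors[i]:
        for x in range(K):
            belief[x] *= message(n, i)[x]
    z = sum(belief)
    return [b / z for b in belief]

def brute_marginal(i):
    """Reference answer by enumerating every joint assignment."""
    totals = [0.0] * K
    for xs in product(range(K), repeat=len(phi)):
        w = 1.0
        for n, x in enumerate(xs):
            w *= phi[n][x]
        for a, b in edges:
            w *= psi[xs[a]][xs[b]]
        totals[xs[i]] += w
    z = sum(totals)
    return [t / z for t in totals]

all_marginals = [marginal(i) for i in range(len(phi))]
```

After all four marginals are requested, the cache holds one message per direction of each edge, and every later query reuses them.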

Junction tree

Previously, we solved the inference problem under the assumption that the graph is a tree. Now, we extend it with the junction tree method to handle a general graph. The general concept is divide-and-conquer. The graph may not be a tree, but we may partition it into a tree-like form of clusters, each holding a subset of the variables. For example, the graph G below can be partitioned into a tree containing four clusters, in circles on the right below.

Image for post
Source

The variables within a cluster can be highly coupled. The hope is that the graph is now divided into much smaller clusters, so the calculation within each cluster is manageable and can be solved exactly. Interaction between clusters follows a tree structure. Once each cluster is solved, we can solve the problem globally by message passing, as in the belief propagation algorithm.

To qualify as a junction tree,

  • There exists exactly one path between each pair of clusters,
  • Every clique in G must be contained in some cluster, and
  • For a pair of clusters containing node i, each cluster on the path between these two clusters must also contain node i.
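The third condition is often called the running intersection property. In a tree, it is equivalent to saying that the clusters containing any given variable form a connected subtree, which is easy to check. The sketch below assumes the cluster tree is already a tree; the function name and the example clusters are hypothetical.

```python
def satisfies_running_intersection(clusters, tree_edges):
    """clusters: {name: set of variables}; tree_edges: list of (name, name).

    Assumes tree_edges already forms a tree over the cluster names.
    """
    adj = {c: set() for c in clusters}
    for a, b in tree_edges:
        adj[a].add(b)
        adj[b].add(a)
    variables = set().union(*clusters.values())
    for var in variables:
        holders = {c for c, members in clusters.items() if var in members}
        # The clusters containing var must form a connected subtree.
        start = next(iter(holders))
        seen, stack = {start}, [start]
        while stack:
            cur = stack.pop()
            for nxt in adj[cur]:
                if nxt in holders and nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        if seen != holders:
            return False
    return True

# Hypothetical clusters over the variables a..e.
clusters = {"C1": {"a", "b", "c"}, "C2": {"b", "c", "d"}, "C3": {"c", "d", "e"}}
good = satisfies_running_intersection(clusters, [("C1", "C2"), ("C2", "C3")])
bad = satisfies_running_intersection(clusters, [("C1", "C3"), ("C3", "C2")])
```

The second arrangement fails because b appears in C1 and C2 but not in C3, which sits on the path between them.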

Building a Junction tree

Before building a junction tree, if the graph is directed, we have to moralize it first. Then, the graph needs to be triangulated so that it becomes chordal.

Image for post
Modified from source

A graph is not chordal if it contains a cycle of four or more nodes that has no chord, i.e. no edge joining two nodes that are not adjacent in the cycle. For example, on the top right above, the thick line forms a cycle, but we need to add two more links to make the graph chordal. Without proving it here, every triangulated graph has a junction tree.

Starting with a variable elimination order, the junction tree is built as:

Image for post
Source

After step 1, we only keep the maximal cliques. Therefore, the clique (a, b) is thrown away because it is a subset of the larger clique (a, b, c). In step 2, we build the junction tree from the cliques above as a maximum-weight spanning tree.

Image for post
Source

The weight of the edge between two clusters is the number of variables they have in common. For example, the purple and blue clusters have two common variables and therefore the weight is two. Finally, we compute the maximum-weight spanning tree (the thicker line) and use it as the junction tree.
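The spanning-tree step can be sketched with Kruskal's algorithm, taking the heaviest candidate edges first. The clique names and contents below are hypothetical, and union-find is just one convenient way to detect cycles.

```python
def junction_tree_edges(cliques):
    """Maximum-weight spanning tree over cliques (Kruskal with union-find).

    cliques maps a name to its set of variables; the weight of an edge is
    the number of variables the two cliques share.
    """
    names = list(cliques)
    candidates = []
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            weight = len(cliques[a] & cliques[b])
            if weight > 0:
                candidates.append((weight, a, b))
    candidates.sort(key=lambda e: -e[0])  # heaviest edges first (stable sort)

    parent = {n: n for n in names}

    def find(n):
        while parent[n] != n:
            parent[n] = parent[parent[n]]  # path compression
            n = parent[n]
        return n

    tree = []
    for weight, a, b in candidates:
        ra, rb = find(a), find(b)
        if ra != rb:                       # the edge joins two components
            parent[ra] = rb
            tree.append((a, b, weight))
    return tree

# Hypothetical maximal cliques of a triangulated graph.
cliques = {"C1": {"a", "b", "c"}, "C2": {"b", "c", "d"}, "C3": {"c", "d", "e"}}
tree = junction_tree_edges(cliques)
```

The weight-1 edge between C1 and C3 is rejected because the two weight-2 edges already connect all three cliques.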

Here is an example of how we marginalize x₆ and create a new message m(x₂, x₃, x₅).

Image for post
Modified from source

The message from cluster i to cluster j is calculated by multiplying the potential of cᵢ by the messages from its other neighboring clusters and then marginalizing out the variables in cᵢ that are not in cⱼ.

Image for post
Source of the equation

The belief of the cluster c based on the messages it received is

Image for post

Therefore, the marginal probability for x can be computed by marginalizing over the other variables in the cluster.

Image for post

Loopy belief propagation

The complexity of inference using the junction tree grows exponentially with the number of variables in the largest cluster. For some graphical models, this cluster can be huge, in particular when the graph is loopy (containing many non-localized cycles). For an Ising model on an N × N grid, the width of the junction tree is N. The corresponding clique can have 2ᴺ entries.

Image for post
Ising model

But sometimes, we may settle for a good enough solution rather than an exact solution. In the belief propagation, a message is propagated from node i to j when node i receives all the messages from its neighbors other than node j. In the loopy belief propagation, all messages are typically initialized uniformly at the beginning. Then we have an ordered list containing all edges. Going through the list one-by-one (say node i → node j), we propagate the message from i to j using the following equation.

Image for post

This is the same equation as in belief propagation, where a node waits for its neighboring messages to be finalized before sending its own. But in loopy belief propagation, it does not wait. It uses the message values of the current time step (starting from the uniform initialization). We go through the list of edges and keep iterating the messages. Messages may circulate indefinitely. We perform this iteration for a fixed number of steps or until it converges. In practice, it often works surprisingly well.
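A minimal sketch of this procedure on a graph with a single loop follows. The 3-node cycle, its potentials, and the choice of 50 sweeps are assumptions for illustration; messages are normalized after each update, a common practical step that does not change the resulting beliefs.

```python
K = 2
edges = [(0, 1), (1, 2), (2, 0)]            # hypothetical 3-node cycle
phi = [[1.0, 2.0], [1.0, 1.0], [2.0, 1.0]]  # node potentials
psi = [[2.0, 1.0], [1.0, 2.0]]              # shared symmetric edge potential

neighbors = {i: [] for i in range(len(phi))}
for i, j in edges:
    neighbors[i].append(j)
    neighbors[j].append(i)

# One message per directed edge, initialized uniformly.
directed = [(i, j) for i in neighbors for j in neighbors[i]]
msg = {e: [1.0 / K] * K for e in directed}

for sweep in range(50):                     # fixed number of sweeps
    for i, j in directed:                   # ordered list of directed edges
        new = []
        for xj in range(K):
            total = 0.0
            for xi in range(K):
                term = phi[i][xi] * psi[xi][xj]
                for n in neighbors[i]:
                    if n != j:
                        term *= msg[(n, i)][xi]  # current value, no waiting
                total += term
            new.append(total)
        s = sum(new)
        msg[(i, j)] = [v / s for v in new]  # normalize to keep values bounded

def approx_marginal(i):
    """Approximate marginal: node potential times incoming messages."""
    belief = [phi[i][x] for x in range(K)]
    for n in neighbors[i]:
        for x in range(K):
            belief[x] *= msg[(n, i)][x]
    z = sum(belief)
    return [b / z for b in belief]

estimate = approx_marginal(0)
```

For this small model, enumerating the eight joint states gives an exact p(x₀ = 1) of 0.6, so the loopy estimate can be compared against it directly.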

Maximum a posteriori — MAP inference

In the last few sections, we focused on marginal inference. Now, we turn to the second category: MAP inference.

Image for post
Modified from source

Max-product inference (Viterbi algorithm)

It turns out we can use the concept of belief propagation to solve MAP inference. Instead of summing over the factors as in marginal inference, we find the maximum of the related factors, i.e. instead of sum-product inference, MAP inference is max-product inference. The message between two nodes will be

Image for post

And the optimal value xᵢ* will be

Image for post

So the mechanism is almost the same as sum-product inference, except that we replace the summation in the marginalization with the maximum function. The potential 𝜙 can also be expressed as an energy model with parameter θ. Taking the logarithm turns the product into a sum, so the corresponding MAP inference becomes max-sum inference instead of max-product inference.

Image for post
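On a chain, max-product message passing with back-pointers is the Viterbi algorithm. The sketch below assumes a hypothetical three-node chain with made-up potentials; the function name and the table layout are illustrative choices.

```python
K = 2
# Hypothetical chain x0 - x1 - x2 with made-up potentials.
phi = [[1.0, 2.0], [1.0, 1.0], [3.0, 1.0]]  # node potentials
psi = [[3.0, 1.0], [1.0, 3.0]]              # shared edge potential

def viterbi(phi, psi):
    """Max-product on a chain; returns the best assignment and its score."""
    n = len(phi)
    # m[t][x] = best score of any prefix that ends with x_t = x
    m = [list(phi[0])]
    back = []
    for t in range(1, n):
        row, ptr = [], []
        for x in range(K):
            # Max instead of sum over the previous variable.
            best_prev = max(range(K), key=lambda p: m[-1][p] * psi[p][x])
            row.append(m[-1][best_prev] * psi[best_prev][x] * phi[t][x])
            ptr.append(best_prev)
        m.append(row)
        back.append(ptr)
    # Backtrack from the best final state.
    x = max(range(K), key=lambda s: m[-1][s])
    path = [x]
    for ptr in reversed(back):
        x = ptr[x]
        path.append(x)
    return path[::-1], max(m[-1])

best_path, best_score = viterbi(phi, psi)
```

Replacing each potential by its logarithm and each product by a sum gives the max-sum form, which selects the same assignment.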

Exact solutions & Approximation methods

For inference, if the model is not too complex, we have illustrated the following methods for finding an exact solution.

  • Variable elimination algorithm
  • Message-passing algorithm (belief propagation — sum-product inference for marginal distribution or max-product inference for MAP)
  • Junction tree algorithm

But exact solutions can be hard. We may fall back to approximation methods in solving our problems. They may include

  • Loopy belief propagation
  • Sampling methods
  • Variational inference

For the next couple of articles, we will explore the approximation methods.

Credits and references

Credits and references are listed in the first article of the Graphical model series.
