Machine Learning — Variational Inference
Bayes’ theorem looks deceptively simple. However, its denominator is the partition function, which integrates over z, and in general it cannot be computed analytically. Even if we model the prior and the likelihood with known families of distributions, the posterior p(z|x) generally remains intractable.
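For reference, Bayes’ theorem in the notation used here (z hidden, x observed):

```latex
p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)}
```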
Let’s demonstrate the complexity with a simple example. We use a multinomial distribution to select one of K normal distributions, and then sample xᵢ from the selected normal distribution. Even for this simple model, the posterior is already unmanageable.
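One common way to write such a model is sketched below; the Gaussian prior on the K means (variance σ²) and the unit variance of each component are assumptions made for illustration. Expanding the product of sums in p(x) produces Kⁿ terms, one per assignment of the n data points to components, each with its own integral over μ, which is why the posterior’s denominator becomes unmanageable.

```latex
\mu_k \sim \mathcal{N}(0, \sigma^2), \quad k = 1, \dots, K
c_i \sim \mathrm{Categorical}(1/K, \dots, 1/K), \qquad x_i \mid c_i, \boldsymbol{\mu} \sim \mathcal{N}(\mu_{c_i}, 1)
p(\mathbf{x}) = \int p(\boldsymbol{\mu}) \prod_{i=1}^{n} \sum_{c_i=1}^{K} p(c_i)\, p(x_i \mid c_i, \boldsymbol{\mu}) \, d\boldsymbol{\mu}
```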
Instead, we can approximate the solution. In ML, there are two major approximation approaches: sampling and variational inference. In this article, we discuss the latter.
In variational inference, given the observations X, we build a probability model q for the latent variables z such that q(z) ≈ p(z|X).
The marginal p(X) above can be computed as:
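Written out, the marginal integrates the joint distribution over every configuration of the latent variables:

```latex
p(X) = \int p(X, z)\, dz = \int p(X \mid z)\, p(z)\, dz
```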
In variational inference, we avoid computing the marginal p(X), because this partition function is usually intractable. Instead, we select a tractable family of distributions q to approximate p.
We fit q to the sample data to learn its distribution parameters θ. We choose q so that it is easy to manipulate: for example, its expectation and normalization factor can be computed directly from the distribution parameters. Because of this choice, we can use q in place of p for inference and analysis.
Overview
While the concept sounds simple, the details are not. In this section, we walk through the major steps using the famous topic modeling algorithm Latent Dirichlet Allocation (LDA). We hope this gives you a top-level overview before we dig into the details and the proofs.
The following is the Graphical model for LDA.
This model contains the variables α, β, θ, z, and w. Don’t worry about what the variables mean, since that is not important in our context. w is the observation. θ and z are the hidden variables (latent factors) that we want to discover. α and β are fixed and known in our discussion. The arrows in the Graphical model indicate dependence. For example, w depends only on z and β. Hence, p(w|α, β, θ, z) simplifies to p(w|z, β).
Like many probability models, we are interested in modeling the joint distribution p(w, θ, z|α, β) given the known inputs. We apply the chain rule to expand the joint probability into a product of distributions, each over a single variable. Then we use the dependencies in the graph to simplify each term. We get:
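For a single document with N words, the standard LDA factorization reads as follows (a sketch; n indexes the words):

```latex
p(w, \theta, z \mid \alpha, \beta) = p(\theta \mid \alpha) \prod_{n=1}^{N} p(z_n \mid \theta)\, p(w_n \mid z_n, \beta)
```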
Based on the topic modeling problem, θ can be modeled with a Dirichlet distribution, and z and w with multinomial distributions. Our objective is to approximate p with q for all the hidden variables θ and z.
We define an objective that minimizes the difference between p and q. This can be done by maximizing the ELBO (evidence lower bound):
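In standard form, the ELBO for LDA is the expected log joint under q minus the expected log of q itself:

```latex
\mathcal{L} = \mathbb{E}_q[\log p(w, \theta, z \mid \alpha, \beta)] - \mathbb{E}_q[\log q(\theta, z)]
```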
Although it is not obvious, the ELBO is maximized when q and p are the same. However, the joint distribution q(θ, z) is still too hard to model, so we break it up and approximate it as q(θ, z) ≈ q(θ) q(z). This may not be perfect, but the empirical results are usually good. z consists of multiple variables z₁, z₂, z₃, ... and can be decomposed further into individual components q(z₁)q(z₂).... Therefore, the final model for q is:
According to the topic modeling problem, we can model θ with a Dirichlet distribution and each zᵢ with a multinomial distribution, with γ and φᵢ as the corresponding distribution parameters. This is a major milestone: we have broken a complex model into distributions of individual hidden variables and selected a tractable distribution for each one. The remaining question is how to learn γ and φᵢ. Let’s get back to the ELBO objective:
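With one factor for θ and one per word position n, the mean-field model is:

```latex
q(\theta, z) = q(\theta) \prod_{n=1}^{N} q(z_n)
```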
In many ML problems, an effective model has hidden variables that depend on each other, so we cannot optimize them all in a single step. Instead, we optimize one variable at a time while holding the others fixed, rotating through the hidden variables in alternating steps until the solution converges. In LDA, z and θ are optimized separately in steps 5 and 6 below.
The remaining major question is how to optimize one variational distribution while fixing the others. In each iteration, the optimal distribution for the targeted hidden variable zₖ is:
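In the usual mean-field notation, where E₋ₖ denotes an expectation over all hidden variables except zₖ, the update is:

```latex
q_k^*(z_k) = \frac{\exp\big(\mathbb{E}_{-k}[\log p(z_k, z_{-k}, x)]\big)}{\int \exp\big(\mathbb{E}_{-k}[\log p(z_k, z_{-k}, x)]\big)\, dz_k}
```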
The expectation in the numerator is taken over all the hidden variables except zₖ.
It sounds like we are reintroducing the evil twin: the normalization factor. Nevertheless, it is not a problem. We choose q to be tractable distributions whose expectation and normalization factor can be derived analytically from the distribution parameters.
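The numerator in the equation deserves more explanation. For a regular expectation E[f(x₁, x₂, x₃)], we evaluate f over all the variables. In our numerator, however, we omit the targeted variable zₖ and take the expectation over all the other hidden variables, where the subscript -k is short for “all hidden variables except zₖ”: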
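Side by side, assuming factorized distributions qⱼ, the two kinds of expectation are:

```latex
\mathbb{E}[f(x_1, x_2, x_3)] = \iiint f(x_1, x_2, x_3)\, q_1(x_1)\, q_2(x_2)\, q_3(x_3)\, dx_1\, dx_2\, dx_3
\mathbb{E}_{-k}[\log p(z, x)] = \int \log p(z, x) \prod_{j \neq k} q_j(z_j)\, dz_{-k}, \qquad dz_{-k} = \prod_{j \neq k} dz_j
```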
In practice, however, we do not perform this integration to compute the expectation. Our choice of qᵢ lets us simplify many calculations when maximizing the ELBO. Let’s look at this in more detail.
In LDA, q is approximated as:
where θ and z are modeled by γ and 𝜙 respectively. Our calculation involves:
- Expand the ELBO into terms of individual variables
- Compute the expected value
- Optimize the ELBO
Expand ELBO
Using the Graphical model and the chain rule, we expand the ELBO as:
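Following the LDA factorization and the mean-field q, the ELBO splits into five expectations:

```latex
\mathcal{L} = \mathbb{E}_q[\log p(\theta \mid \alpha)] + \mathbb{E}_q[\log p(z \mid \theta)] + \mathbb{E}_q[\log p(w \mid z, \beta)] - \mathbb{E}_q[\log q(\theta)] - \mathbb{E}_q[\log q(z)]
```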
Compute the expected value
We don’t want to overwhelm you with details, so we only demonstrate how to compute the expectation for the first term. First, θ is modeled by a Dirichlet distribution with parameter α.
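The Dirichlet density and its logarithm, with i running over the K topics, are:

```latex
p(\theta \mid \alpha) = \frac{\Gamma\big(\sum_{i} \alpha_i\big)}{\prod_{i} \Gamma(\alpha_i)} \prod_{i} \theta_i^{\alpha_i - 1}
\log p(\theta \mid \alpha) = \log \Gamma\Big(\sum_{j} \alpha_j\Big) - \sum_{i} \log \Gamma(\alpha_i) + \sum_{i} (\alpha_i - 1) \log \theta_i
```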
Next, we will compute its expectation w.r.t. q.
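Only the last term depends on θ, so taking the expectation w.r.t. q leaves:

```latex
\mathbb{E}_q[\log p(\theta \mid \alpha)] = \log \Gamma\Big(\sum_{j} \alpha_j\Big) - \sum_{i} \log \Gamma(\alpha_i) + \sum_{i} (\alpha_i - 1)\, \mathbb{E}_q[\log \theta_i]
```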
We state without proof that E[log θᵢ] can be calculated directly from γ.
We choose q thoughtfully, usually picking well-known distributions based on the properties of the hidden variables in the problem statement. Mathematicians have already solved those expectation expressions analytically, so we don’t even need to worry about the normalization factor.
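The known identity for a Dirichlet q(θ|γ) is E_q[log θᵢ] = Ψ(γᵢ) − Ψ(Σⱼ γⱼ), where Ψ is the digamma function. The sketch below checks it numerically against a Monte Carlo estimate; the γ values are hypothetical and chosen only for the check.

```python
import numpy as np
from scipy.special import digamma

gamma = np.array([2.0, 5.0, 1.0])  # hypothetical Dirichlet parameters of q(theta)

# Closed form: E_q[log theta_i] = digamma(gamma_i) - digamma(sum_j gamma_j)
analytic = digamma(gamma) - digamma(gamma.sum())

# Monte Carlo estimate of the same expectation, for comparison
rng = np.random.default_rng(0)
samples = rng.dirichlet(gamma, size=200_000)
monte_carlo = np.log(samples).mean(axis=0)

print("analytic:   ", np.round(analytic, 3))
print("Monte Carlo:", np.round(monte_carlo, 3))
```

No integration or normalization is needed for the closed form: two digamma evaluations per component are enough.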
Optimize ELBO
After we expand all the remaining terms in the ELBO, we differentiate it w.r.t. γᵢ (the ith parameter in γ) and 𝜙nᵢ (the ith parameter for the nth word). Setting the derivatives to zero, we find the optimal solution for γᵢ as:
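In the standard LDA derivation, this update is:

```latex
\gamma_i = \alpha_i + \sum_{n=1}^{N} \phi_{ni}
```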
And the optimal solution for 𝜙nᵢ will be:
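In the standard LDA derivation, with β_{i,wₙ} = p(wₙ | zₙ = i) and 𝜙ₙ normalized to sum to 1 over the topics i, this update is:

```latex
\phi_{ni} \propto \beta_{i, w_n} \exp\Big(\Psi(\gamma_i) - \Psi\Big(\sum_{j} \gamma_j\Big)\Big)
```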
Because of the dependence between γ and 𝜙nᵢ, we will optimize the parameters iteratively in alternating steps.
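A minimal sketch of this alternating optimization for one document, assuming β is a known K × V matrix of topic-word probabilities and using the standard LDA updates for γ and 𝜙. The function name `lda_e_step` and the toy data are hypothetical.

```python
import numpy as np
from scipy.special import digamma

def lda_e_step(word_ids, alpha, beta, max_iter=100, tol=1e-6):
    """Variational updates for one document (hypothetical helper).

    word_ids: length-N array of vocabulary indices w_n
    alpha:    length-K Dirichlet prior
    beta:     K x V topic-word probabilities (rows sum to 1)
    Returns gamma (length K) and phi (N x K).
    """
    n_words, n_topics = len(word_ids), len(alpha)
    phi = np.full((n_words, n_topics), 1.0 / n_topics)
    gamma = alpha + n_words / n_topics
    for _ in range(max_iter):
        # phi_ni ∝ beta[i, w_n] * exp(digamma(gamma_i) - digamma(sum_j gamma_j))
        log_phi = np.log(beta[:, word_ids].T) + digamma(gamma) - digamma(gamma.sum())
        log_phi -= log_phi.max(axis=1, keepdims=True)  # numerical stability
        phi = np.exp(log_phi)
        phi /= phi.sum(axis=1, keepdims=True)
        # gamma_i = alpha_i + sum_n phi_ni
        new_gamma = alpha + phi.sum(axis=0)
        converged = np.abs(new_gamma - gamma).max() < tol
        gamma = new_gamma
        if converged:
            break
    return gamma, phi

rng = np.random.default_rng(0)
K, V = 3, 10
alpha = np.full(K, 0.1)
beta = rng.dirichlet(np.ones(V), size=K)  # hypothetical topics
doc = rng.integers(0, V, size=20)          # hypothetical 20-word document
gamma, phi = lda_e_step(doc, alpha, beta)
print(np.round(gamma, 2))
```

Each pass updates 𝜙 with γ fixed and then γ with 𝜙 fixed, mirroring the alternating steps described in the text.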
That concludes the overview. For the rest of the article, we cover some major design decisions in variational inference, the proofs, and a detailed example.
KL-divergence
To find q, we turn the problem into an optimization problem: we compute the parameters of q that minimize the reverse KL-divergence KL(q, p*) to the target p*.
As shown before, KL-divergence is not symmetric. The optimal q for KL(p, q) and for KL(q, p) are the same only when q is expressive enough to model p exactly. This raises an important question: why use the reverse KL-divergence KL(q, p) when the KL-divergence KL(p, q) matches the expectation of p better? For example, when we use a Gaussian distribution to model the bimodal distribution in blue, the reverse KL-divergence solution is either the red curve in diagram (b) or the one in (c). Both solutions cover only one mode.
However, the KL-divergence solution in (a) covers most of the original distribution, and its mean matches the mean of p*.
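Written out for a discrete z:

```latex
q^* = \arg\min_q \mathrm{KL}(q \,\|\, p^*) = \arg\min_q \sum_z q(z) \log \frac{q(z)}{p^*(z)}
```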
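The contrast between the two objectives can be reproduced numerically. The sketch below assumes a hypothetical bimodal target, 0.5·N(−3, 1) + 0.5·N(3, 1), on a fixed grid. For a Gaussian q, the KL(p, q) optimum is found by moment matching; the KL(q, p) optimum is found by a brute-force grid search, chosen here for transparency rather than efficiency.

```python
import numpy as np

def log_normal(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

x = np.linspace(-15.0, 15.0, 6001)
dx = x[1] - x[0]
# Hypothetical target p*: two well-separated modes
log_p = np.logaddexp(np.log(0.5) + log_normal(x, -3.0, 1.0),
                     np.log(0.5) + log_normal(x, 3.0, 1.0))
p = np.exp(log_p)

def reverse_kl(mu, sigma):
    """KL(q || p*) for q = N(mu, sigma^2), by numerical integration on the grid."""
    log_q = log_normal(x, mu, sigma)
    return np.sum(np.exp(log_q) * (log_q - log_p)) * dx

# KL(p* || q) over Gaussians is minimized by matching mean and variance (m-projection)
mu_fwd = np.sum(x * p) * dx
sigma_fwd = np.sqrt(np.sum((x - mu_fwd) ** 2 * p) * dx)

# KL(q || p*) (i-projection): brute-force search over a grid of (mu, sigma)
grid = [(m, s) for m in np.linspace(-5, 5, 101) for s in np.linspace(0.5, 4.0, 36)]
mu_rev, sigma_rev = min(grid, key=lambda ms: reverse_kl(*ms))

print(f"KL(p, q) solution: mu={mu_fwd:.2f}, sigma={sigma_fwd:.2f}")
print(f"KL(q, p) solution: mu={mu_rev:.2f}, sigma={sigma_rev:.2f}")
```

The moment-matched Gaussian sits between the modes with a wide spread, like (a); the reverse-KL optimum locks onto one mode, like (b) or (c).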
Moments, including the mean and the variance, describe a distribution. The KL-divergence solution is a moment projection (m-projection): it matches the moments of q with those of p. If we match all the moments, the two distributions are exactly the same (i.e. p = q). If q is from the exponential family, minimizing the KL-divergence matches the moments of q with p* exactly; without further explanation here, their expected sufficient statistics will match.
The reverse KL-divergence is an information projection (i-projection), which does not necessarily yield the right moments. Judging from this, we may conclude that m-projection is superior. However, a mechanism that matches p* exactly must understand p* fully, and that is what is hard in the first place. So m-projection is not as attractive as it sounds.
In variational inference, i-projection is used instead. To justify this choice, let’s bring up two constraints that we want to follow. First, we want to avoid computing the partition function because the calculation is hard. Second, we want to avoid computing p(z), since computing it requires the partition function. So let’s define a new term for p, the unnormalized distribution p̃, which separates the partition function out.
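In symbols, with Z the partition function:

```latex
p^*(z) = \frac{\tilde{p}(z)}{Z}, \qquad Z = \sum_z \tilde{p}(z)
```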
Let’s plug the new definition into the reverse KL-divergence.
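Substituting p* = p̃/Z separates log Z from the rest of the objective:

```latex
\mathrm{KL}(q \,\|\, p^*) = \sum_z q(z) \log \frac{q(z)}{p^*(z)} = \sum_z q(z) \log \frac{q(z)}{\tilde{p}(z)} + \log Z
\arg\min_q \mathrm{KL}(q \,\|\, p^*) = \arg\min_q \sum_z q(z) \log \frac{q(z)}{\tilde{p}(z)}
```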
Z does not depend on q, so it can be ignored when we minimize the reverse KL-divergence.
This is great news. In the Graphical model, the unnormalized p̃ is well-defined as a product of factors. It is easy to compute, and the objective on the R.H.S. does not need any normalization. Using the reverse KL-divergence is a good compromise, even if it is not perfect in certain scenarios: if q is overly simple compared with p*, the result may suffer. However, variational inference usually demonstrates good empirical results. Next, let’s see how to optimize the reverse KL-divergence.
Evidence lower bound (ELBO)
Let’s introduce Jensen’s inequality for a convex function f, and a term called the evidence lower bound (ELBO):
The ELBO is literally a lower bound of the evidence (log p(x)): it follows from applying Jensen’s inequality to the concave log function in the last step.
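Jensen’s inequality and its reversed form for the concave log function give the bound directly:

```latex
f(\mathbb{E}[X]) \le \mathbb{E}[f(X)] \ \text{for convex } f, \qquad \log \mathbb{E}[X] \ge \mathbb{E}[\log X]
\log p(x) = \log \int p(x, z)\, dz = \log \mathbb{E}_q\Big[\frac{p(x, z)}{q(z)}\Big] \ge \mathbb{E}_q[\log p(x, z) - \log q(z)] = \mathrm{ELBO}
```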
The ELBO is related to the KL-divergence as follows:
Let Z be the marginal p(x) for now. Don’t confuse Z with the hidden variables z. Unfortunately, we overload the notation with a capital letter because Z is commonly used for this quantity in the literature.
Z does not depend on how we model q. So from the perspective of optimizing q, log Z is a constant.
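With Z = p(x), expanding the KL-divergence to the posterior gives:

```latex
\mathrm{KL}(q(z) \,\|\, p(z \mid x)) = \mathbb{E}_q[\log q(z) - \log p(x, z)] + \log p(x) = -\mathrm{ELBO} + \log Z
\log Z = \mathrm{ELBO} + \mathrm{KL}(q(z) \,\|\, p(z \mid x))
```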
Therefore, minimizing the KL-divergence is the same as maximizing the ELBO. Intuitively, for any distribution q, the ELBO is a lower bound of log Z, and when q equals p*, the gap shrinks to zero. Therefore, maximizing the ELBO drives the KL-divergence toward zero; it reaches zero only if the family chosen for q contains p*.
By maximizing the ELBO, we minimize the difference between the two distributions.
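The identity log Z = ELBO + KL can be checked on a toy discrete problem. The sketch assumes a hypothetical hidden variable z with 5 states and an arbitrary table standing in for p(x, z) at one fixed observation x.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 5                                   # hypothetical number of states of z
p_xz = 0.1 * rng.random(K) + 1e-3       # hypothetical joint p(x, z) for one fixed x
q = rng.dirichlet(np.ones(K))           # an arbitrary variational distribution q(z)

def elbo_of(q):
    # ELBO = E_q[log p(x, z)] - E_q[log q(z)]
    return np.sum(q * (np.log(p_xz) - np.log(q)))

log_Z = np.log(p_xz.sum())              # evidence log p(x), i.e. log Z
posterior = p_xz / p_xz.sum()           # p*(z) = p(z | x)
elbo = elbo_of(q)
kl = np.sum(q * (np.log(q) - np.log(posterior)))
elbo_at_posterior = elbo_of(posterior)  # the gap closes when q = p*

print(f"log Z = {log_Z:.6f}, ELBO + KL = {elbo + kl:.6f}")
print(f"ELBO(q) = {elbo:.6f}, ELBO(p*) = {elbo_at_posterior:.6f}")
```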
Let’s generalize the ELBO as
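For an unnormalized p̃, a natural general form is:

```latex
\mathrm{ELBO}(q) = \mathbb{E}_q[\log \tilde{p}(z)] - \mathbb{E}_q[\log q(z)] = \log Z - \mathrm{KL}\big(q \,\|\, \tilde{p} / Z\big), \qquad Z = \int \tilde{p}(z)\, dz
```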
where Z is now a general normalization factor.
Again, as shown above, maximizing the ELBO is the same as minimizing the KL-divergence, since Z does not depend on how we model q.
This is a major advantage over using the KL-divergence directly: the ELBO works for both normalized and unnormalized distributions, and we do not need to calculate Z, which the regular KL-divergence definition requires.
ELBO and the Graphical model (optional)
Let’s demonstrate how the unnormalized distribution is computed in the ELBO using the Graphical model. The joint probability distribution can be modeled by a Markov Random Field (MRF) as:
We replace the unnormalized p in the ELBO with the product of the factors 𝜙 above.
Therefore, minimizing the KL-divergence is equivalent to minimizing the Gibbs free energy. We call it free energy because it is the part of the energy that we can manipulate by changing the configuration. The formulation can be expanded further if we express the factors with an energy model.
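In a typical MRF notation, where C_k (an assumed label) is the set of variables in clique k:

```latex
p(x) = \frac{1}{Z} \prod_{k} \phi_k(x_{C_k}), \qquad \tilde{p}(x) = \prod_{k} \phi_k(x_{C_k})
```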
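With log p̃ = Σₖ log 𝜙ₖ and H(q) the entropy of q, the ELBO and its negative, the free energy F(q), become:

```latex
\mathrm{ELBO}(q) = \sum_{k} \mathbb{E}_q[\log \phi_k(x_{C_k})] + H(q)
F(q) = -\sum_{k} \mathbb{E}_q[\log \phi_k(x_{C_k})] - H(q), \qquad \mathrm{KL}(q \,\|\, p) = F(q) + \log Z
```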
Mean Field Variational Inference
(Credit: the proof and the equations originate from here.)
Don’t celebrate too soon. We have skipped an important and difficult step in variational inference: what should q be? Choosing it can be extremely hard when q involves multiple variables, i.e. q(z) = q(z₁, z₂, z₃, …). To reduce the complexity, mean field variational inference makes a bold assumption: the distribution can be broken down into distributions that each involve only one hidden variable.
Then we model each distribution with a tractable distribution based on the problem, chosen so that it is easy to analyze analytically. For example, if z₁ is multinomial, we model it with a multinomial distribution. As mentioned before, many hidden variables depend on each other, so we use coordinate descent to optimize them. We partition the hidden variables into groups, each containing independent variables, and optimize one group at a time while holding the others fixed, rotating through the groups until the solution converges.
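The mean-field assumption for m hidden variables:

```latex
q(z) = \prod_{i=1}^{m} q_i(z_i)
```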
So the last difficult question is how to optimize qᵢ(zᵢ) in each iteration step. We first introduce two facts. First, the chain rule of probability lets us factor the joint distribution as follows, with p(x) pulled out as a term that does not depend on z:
Second, since we factorize q(z) into independent components qᵢ(zᵢ), the entropy is the sum of the individual entropies.
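For hidden variables z₁, …, zₘ in any fixed order:

```latex
p(z_{1:m}, x) = p(x) \prod_{j=1}^{m} p(z_j \mid z_{1:(j-1)}, x)
```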
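In symbols:

```latex
\mathbb{E}_q[\log q(z)] = \sum_{j=1}^{m} \mathbb{E}_{q_j}[\log q_j(z_j)]
```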
With these two facts, we expand the ELBO into:
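Combining the chain rule and the entropy decomposition:

```latex
\mathcal{L} = \mathbb{E}_q[\log p(z, x)] - \mathbb{E}_q[\log q(z)] = \log p(x) + \sum_{j=1}^{m} \mathbb{E}_q[\log p(z_j \mid z_{1:(j-1)}, x)] - \sum_{j=1}^{m} \mathbb{E}_{q_j}[\log q_j(z_j)]
```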
The ordering of the zⱼ in z is arbitrary. We reorder them so that zₖ is the last element and group everything unrelated to z into a constant. The equation becomes:
We further drop the terms unrelated to zₖ and express the result in integral form.
We take the derivative and set it to zero to find the optimal distribution q(zₖ).
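After relabeling so that zₖ is the last of the m hidden variables (so z₋ₖ = z₁:ₘ₋₁), the expansion reads:

```latex
\mathcal{L} = \text{const} + \sum_{j=1}^{m-1} \mathbb{E}_q[\log p(z_j \mid z_{1:(j-1)}, x)] + \mathbb{E}_q[\log p(z_k \mid z_{-k}, x)] - \sum_{j=1}^{m} \mathbb{E}_{q_j}[\log q_j(z_j)]
```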
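Keeping only the terms that involve qₖ:

```latex
\mathcal{L}_k = \int q_k(z_k)\, \mathbb{E}_{-k}[\log p(z_k \mid z_{-k}, x)]\, dz_k - \int q_k(z_k) \log q_k(z_k)\, dz_k + \text{const}
```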
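Taking the functional derivative with respect to qₖ(zₖ):

```latex
\frac{d\mathcal{L}_k}{dq_k(z_k)} = \mathbb{E}_{-k}[\log p(z_k \mid z_{-k}, x)] - \log q_k(z_k) - 1 = 0
```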
The optimal solution is
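Solving for qₖ:

```latex
q_k^*(z_k) = \frac{1}{Z'} \exp\big(\mathbb{E}_{-k}[\log p(z_k \mid z_{-k}, x)]\big)
```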
with all the constants absorbed into Z′. We can expand the numerator with Bayes’ theorem. Again, the corresponding denominator does not depend on zₖ and is therefore absorbed into the normalization factor.
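Using p(zₖ | z₋ₖ, x) = p(zₖ, z₋ₖ, x) / p(z₋ₖ, x), and noting that the expectation of log p(z₋ₖ, x) does not involve zₖ:

```latex
q_k^*(z_k) \propto \exp\big(\mathbb{E}_{-k}[\log p(z_k, z_{-k}, x)]\big)
```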
That is the same equation we got in the overview section.
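The update can be run end to end on a small discrete problem. The sketch below assumes two hypothetical discrete hidden variables z₁, z₂ whose (possibly unnormalized) joint with the fixed observation is given as a table. It applies qₖ* ∝ exp(E₋ₖ[log p]) to each factor in turn and records the ELBO after every sweep.

```python
import numpy as np

def mean_field_two_vars(log_p, n_iter=50):
    """Coordinate-ascent mean-field VI for two discrete hidden variables z1, z2.

    log_p[i, j] holds log p(z1=i, z2=j, x) for the fixed observation x; it may be
    unnormalized, because each update is renormalized explicitly.
    """
    n1, n2 = log_p.shape
    q1 = np.full(n1, 1.0 / n1)
    q2 = np.full(n2, 1.0 / n2)
    elbos = []
    for _ in range(n_iter):
        # q1*(z1) ∝ exp(E_{q2}[log p(z1, z2, x)])
        s1 = log_p @ q2
        q1 = np.exp(s1 - s1.max())
        q1 /= q1.sum()
        # q2*(z2) ∝ exp(E_{q1}[log p(z1, z2, x)])
        s2 = q1 @ log_p
        q2 = np.exp(s2 - s2.max())
        q2 /= q2.sum()
        # ELBO = E_q[log p] + H(q1) + H(q2) for q = q1 * q2
        elbo = q1 @ log_p @ q2 - q1 @ np.log(q1) - q2 @ np.log(q2)
        elbos.append(elbo)
    return q1, q2, np.array(elbos)

rng = np.random.default_rng(1)
log_p = np.log(rng.random((3, 4)) + 0.05)  # hypothetical joint table
q1, q2, elbos = mean_field_two_vars(log_p)
log_Z = np.log(np.exp(log_p).sum())
print(f"final ELBO = {elbos[-1]:.4f} <= log Z = {log_Z:.4f}")
```

Each update maximizes the ELBO exactly in one factor, so the recorded ELBO never decreases, and it stays below log Z because the factorized q generally cannot represent the joint exactly.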
There are other methods for finding the optimal q. Let’s put everything in the context of the MRF. As described before, our objective is
Let’s expand it with q(x) factorized as q(x₁) q(x₂) q(x₃) …
This equation can be solved with linear algebra, similarly to MAP inference, but we will not detail the solution here.
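With the clique notation C_k (an assumed label for the variables in factor k), the factorized objective becomes:

```latex
\sum_{k} \sum_{x_{C_k}} \Big(\prod_{i \in C_k} q_i(x_i)\Big) \log \phi_k(x_{C_k}) - \sum_{i} \sum_{x_i} q_i(x_i) \log q_i(x_i)
```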
Recap
We know the equation for the distribution p, but it is hard to analyze or manipulate.
So, given the observations, we approximate p with a tractable qᵢ for each individual model parameter. For example,
To minimize the difference between p and q, we maximize the ELBO below.
In each iteration step, the optimal solution for the corresponding model parameter zⱼ will be:
Since each q is chosen to be tractable, finding the expectation or the normalization factor (if needed) can be done analytically and is fairly straightforward.
Example
(Credit: the example and some equations originate from here.)
Let’s demonstrate variational inference with an example. Consider the distribution p(x) below:
where μ (mean) and τ (precision) are modeled by Gaussian and Gamma distributions, respectively. Let’s approximate p(x, μ, τ) with q(μ, τ). With variational inference, we can learn both parameters from the data. The optimal q for μ and τ in each iteration satisfies:
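The specific form below is an assumption on our part: the standard Normal-Gamma setup with hyperparameters μ₀, λ₀, a₀, b₀ (b₀ is a rate), which is consistent with the update steps that follow.

```latex
p(X \mid \mu, \tau) = \prod_{n=1}^{N} \mathcal{N}(x_n \mid \mu, \tau^{-1})
p(\mu \mid \tau) = \mathcal{N}\big(\mu \mid \mu_0, (\lambda_0 \tau)^{-1}\big), \qquad p(\tau) = \mathrm{Gam}(\tau \mid a_0, b_0)
```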
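Applying the mean-field update to each factor:

```latex
\log q^*(\mu) = \mathbb{E}_{\tau}[\log p(X, \mu, \tau)] + \text{const}, \qquad \log q^*(\tau) = \mathbb{E}_{\mu}[\log p(X, \mu, \tau)] + \text{const}
```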
Therefore, let’s evaluate p(x, μ, τ) by first expanding it with the chain rule and then applying the definitions from the problem.
Next, we approximate p by q using mean field variational inference:
Now, applying mean field variational inference, we get:
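Under the Normal-Gamma model, the log joint is:

```latex
\log p(X, \mu, \tau) = \log p(X \mid \mu, \tau) + \log p(\mu \mid \tau) + \log p(\tau)
= \frac{N}{2} \log \tau - \frac{\tau}{2} \sum_{n=1}^{N} (x_n - \mu)^2 + \frac{1}{2} \log(\lambda_0 \tau) - \frac{\lambda_0 \tau}{2} (\mu - \mu_0)^2 + (a_0 - 1) \log \tau - b_0 \tau + \text{const}
```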
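The mean-field factorization for this example:

```latex
q(\mu, \tau) = q(\mu)\, q(\tau)
```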
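Keeping only the terms that involve μ:

```latex
\log q^*(\mu) = \mathbb{E}_{\tau}[\log p(X \mid \mu, \tau) + \log p(\mu \mid \tau)] + \text{const} = -\frac{\mathbb{E}[\tau]}{2} \Big\{ \lambda_0 (\mu - \mu_0)^2 + \sum_{n=1}^{N} (x_n - \mu)^2 \Big\} + \text{const}
```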
log q(μ) is quadratic in μ, so q(μ) is Gaussian.
Next, we match the equation above with the Gaussian definition to find the parameters μ and τ (τ⁻¹ = σ²) of q(μ).
Therefore, μ and τ are:
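Writing the mean and precision of q(μ) as μ_N and τ_N, matching the μ² and μ coefficients gives:

```latex
\mu_N = \frac{\lambda_0 \mu_0 + \sum_{n} x_n}{\lambda_0 + N}, \qquad \tau_N = (\lambda_0 + N)\, \mathbb{E}[\tau]
```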
As mentioned, computing the normalization Z is hard in general, but not for these well-known distributions. The normalization factor can be computed from the distribution parameters if needed, so we only need to focus on finding these parameters.
We repeat the same process in computing log q(τ).
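Keeping only the terms that involve τ:

```latex
\log q^*(\tau) = (a_0 - 1) \log \tau - b_0 \tau + \frac{N + 1}{2} \log \tau - \frac{\tau}{2}\, \mathbb{E}_{\mu}\Big[\sum_{n=1}^{N} (x_n - \mu)^2 + \lambda_0 (\mu - \mu_0)^2\Big] + \text{const}
```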
q(τ) is Gamma distributed because the distribution above depends on τ only through τ and log τ. The corresponding parameters a and b of the Gamma distribution are:
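Reading off the coefficients of log τ and τ (writing a_N and b_N for a and b):

```latex
a_N = a_0 + \frac{N + 1}{2}, \qquad b_N = b_0 + \frac{1}{2}\, \mathbb{E}_{\mu}\Big[\sum_{n=1}^{N} (x_n - \mu)^2 + \lambda_0 (\mu - \mu_0)^2\Big]
```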
Now, we have two tractable distributions and we want to find their parameters μ and τ.
Again, let’s rewrite some terms into expectation forms.
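Using the Gamma and Gaussian moments, the expectations and the resulting b_N are:

```latex
\mathbb{E}[\tau] = \frac{a_N}{b_N}, \qquad \mathbb{E}[\mu] = \mu_N, \qquad \mathbb{E}[\mu^2] = \frac{1}{\tau_N} + \mu_N^2
b_N = b_0 + \frac{1}{2} \Big[ (\lambda_0 + N)\, \mathbb{E}[\mu^2] - 2 \mu_N \Big(\lambda_0 \mu_0 + \sum_{n} x_n\Big) + \sum_{n} x_n^2 + \lambda_0 \mu_0^2 \Big]
```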
As promised, mathematicians have already solved these expectation terms analytically. We don’t even need to compute any normalization factor.
μ and a can be solved immediately. But τ depends on b, and b depends on τ.
So, we are going to solve them iteratively in alternating steps.
- Initialize τn to some arbitrary value.
- Solve bn with the equation above.
- Solve τn with the equation above.
- Repeat the last two steps until the values converge.
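A minimal sketch of these steps, assuming the Normal-Gamma model with hyperparameters μ₀, λ₀, a₀, b₀ (b₀ a rate) and writing μ_N, τ_N, a_N, b_N for the text’s μ, τ, a, b. The function name `vi_normal_gamma`, the default hyperparameters, and the synthetic data are hypothetical.

```python
import numpy as np

def vi_normal_gamma(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, tol=1e-10, max_iter=100):
    """Mean-field VI for x_n ~ N(mu, 1/tau), mu | tau ~ N(mu0, 1/(lambda0*tau)),
    tau ~ Gamma(a0, rate=b0). Returns q(mu) = N(mu_n, 1/tau_n), q(tau) = Gamma(a_n, b_n)."""
    n = len(x)
    sum_x, sum_x2 = x.sum(), (x ** 2).sum()
    # mu_n and a_n do not depend on the other factor, so compute them once.
    mu_n = (lambda0 * mu0 + sum_x) / (lambda0 + n)
    a_n = a0 + (n + 1) / 2
    tau_n = 1.0  # step 1: arbitrary initial value
    for _ in range(max_iter):
        # step 2: b_n from E[mu] = mu_n and E[mu^2] = 1/tau_n + mu_n^2
        e_mu2 = 1.0 / tau_n + mu_n ** 2
        b_n = b0 + 0.5 * ((lambda0 + n) * e_mu2
                          - 2.0 * mu_n * (lambda0 * mu0 + sum_x)
                          + sum_x2 + lambda0 * mu0 ** 2)
        # step 3: tau_n from E[tau] = a_n / b_n
        new_tau_n = (lambda0 + n) * a_n / b_n
        converged = abs(new_tau_n - tau_n) < tol * new_tau_n
        tau_n = new_tau_n
        if converged:  # step 4: stop once the values converge
            break
    return mu_n, tau_n, a_n, b_n

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=1000)  # synthetic data: mean 2, precision 4
mu_n, tau_n, a_n, b_n = vi_normal_gamma(x)
print(f"q(mu) = N({mu_n:.3f}, 1/{tau_n:.1f}), E[tau] = {a_n / b_n:.3f}")
```

With a weak prior and 1,000 points, E[τ] = a_N/b_N should land near the precision used to generate the data.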
Sampling vs. Variational inference
Sampling methods have a major shortcoming: we don’t know how far the current sampling solution is from the ground truth. We hope that with enough samples the solution is close, but there is no quantitative measure of it. To measure such a distance, we need an objective function. Since variational inference is formulated as an optimization problem, we do have some indication of progress. However, variational inference approximates the solution rather than finding the exact one. Indeed, it is unlikely that our solution will be exact.
More readings
Topic modeling is one real-life problem that can be solved with variational inference. For readers who want more details: