Machine Learning — Sampling-based Inference
Google DeepMind’s AlphaGo beat the Go champions with smart sampling. Humans solve problems by sampling all the time. Unlike many machine learning (ML) algorithms, we don’t pause to figure out the math first. For example, we learn not to approach our parents at a bad time, and that comes from our sampled experience. Professional gamblers count cards by assigning a value to every card and keeping a running count of the cards seen so far. They change their bets based on the count to increase the odds of winning. Sampling is easy and simple, and that is attractive. In this article, we will review Markov Chain Monte Carlo, importance sampling, the Metropolis algorithm, and Gibbs sampling, which are commonly used in ML.
Challenge
The success of these examples depends on a few critical factors. First, how can we create a dense representation of the data learned so far? Even with a terabyte of memory, it is still no match for the combinations of possible scenarios. In AlphaGo, a deep network is used, and the gamblers keep a running sum to represent the information they need to beat the house. In ML, we may model the data with a Gaussian distribution containing the mean and the covariance, a much denser representation of the information we care about. Card counters use the running sum because the model is simple and easy to manipulate. This is an important principle in ML: choose the model wisely.
But when problems get complex, simple sampling is not enough. Otherwise, everyone could be a Go master. Sampling efficiency is critical for high-dimensional search spaces. The search space is huge, and tiny deviations from the optimal solution can turn a winner into a total loser. Many collected samples provide little to no additional information. Balancing exploration and exploitation is an art: we want to gain the best information with the least effort. However, the best information does not necessarily mean maximizing the objective greedily. Sometimes, we need to make sacrifices for long-term gains. The success of AlphaGo is largely attributable to its sampling efficiency; it samples moves smartly and adaptively.
In sampling-based inference, we make inferences from sampled data. For example, from the sample collected below, the marginal probability p(Gender=male, Age=24) can be computed as 2/18: two of the eighteen sampled rows match Gender=male and Age=24.
Such samples may be collected through interviews from random phone calls. But in ML, we may assume a condensed probability model p has already been built and trained with expert knowledge and training data.
Once the model is known, we can draw samples from this distribution p to make inferences. However, this is not easy in general. We suggest you pause for a minute and think about how to sample data from a distribution as simple as the standard normal. Its cumulative distribution, Φ(x) = (1/√(2π)) ∫₋∞^x e^(−t²/2) dt, is not in closed form, a form that can be computed with a finite number of operations.
Fortunately, the sampling can be done with the Box-Muller transform. If the variables U₁ and U₂ are uniformly distributed on (0, 1), then Z₀ = √(−2 ln U₁) cos(2π U₂) and Z₁ = √(−2 ln U₁) sin(2π U₂) have a standard normal distribution.
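A minimal NumPy sketch of this transform (the names u1, u2, z0, z1 mirror U₁, U₂, Z₀, Z₁ above):

```python
import numpy as np

def box_muller(n, seed=0):
    """Draw n pairs of standard-normal samples via the Box-Muller transform."""
    rng = np.random.default_rng(seed)
    u1 = rng.uniform(size=n)              # U1 ~ Uniform(0, 1)
    u2 = rng.uniform(size=n)              # U2 ~ Uniform(0, 1)
    r = np.sqrt(-2.0 * np.log(u1))
    z0 = r * np.cos(2.0 * np.pi * u2)     # Z0 ~ N(0, 1)
    z1 = r * np.sin(2.0 * np.pi * u2)     # Z1 ~ N(0, 1)
    return z0, z1

z0, z1 = box_muller(100_000)
print(round(z0.mean(), 3), round(z0.std(), 3))   # roughly 0 and 1
```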
However, unless the distribution has some special form, sampling is not easy in general. For example, many ML models are built with Bayesian inference. Even if we can model the prior and the likelihood with well-known distributions, their product is usually no longer in a form that is easy to sample from.
Inference
Let’s change the discussion for a moment. As shown before, given a probabilistic model for p(x₁, x₂, x₃, …, xₙ), inferences can be categorized as:
And, these can be solved through marginal probability
and/or MAP (maximum a posteriori) inference.
However, the complexity of a model grows exponentially with the number of variables. If we cannot discover enough independence to simplify the graphical model, finding an exact solution is infeasible. There are two major approximation methods in ML: sampling-based inference and variational inference. In this article, we will cover the first one.
Let us elaborate on the math further for continuous variables.
As shown again, many calculations resolve to the calculation of marginal probabilities. In many ML problems, and in reinforcement learning, we are also interested in calculating the expected value of f(x) with x sampled from a distribution p. For example, we may want to know the average reward for taking a certain sequence of actions. In fact, marginalization and expectation can be discussed together in the same context, as marginalization can be easily expressed as an expectation.
In sampling-based inference, we collect N data samples drawn from p and use them to estimate quantities such as the marginal probability p(x₁, …).
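As a concrete instance of this identity, written in standard notation rather than tied to any particular model: a marginal is the expectation of an indicator function, estimated by the matching sample average.

```latex
p(x_1 = a) = \sum_{x_2,\dots,x_n} p(a, x_2, \dots, x_n)
           = \mathbb{E}_{x \sim p}\big[\mathbb{1}(x_1 = a)\big]
           \approx \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\big(x_1^{(i)} = a\big),
\qquad x^{(i)} \sim p
```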
That makes sampling appealing: it is simple to do and to understand. The real challenge of sampling-based inference is sampling efficiency. If we need to collect an exponential amount of data to find enough samples relevant to the marginal probability, the solution fails.
Recap
Let’s have a short recap. For high-dimensional spaces, most models cannot be simplified enough to be solved exactly. Sampling-based inference usually requires a huge number of sample points to reach the necessary accuracy. For it to be successful, we need to know:
- how to represent information in a dense model,
- how to sample data from this model, and
- how to sample efficiently.
For this article, we will cover part of the latter two issues.
Monte Carlo Method
In Monte Carlo, where gamblers gather, probability, statistics, and psychology govern. A Monte Carlo method is a fancy name for analyzing a stochastic process through repeated sampling. I was always puzzled about why it is called the Monte Carlo method. It turns out that Stanislaw Ulam, a mathematician and physicist, was participating in the Manhattan Project around World War II. The Monte Carlo method was developed by Ulam under a secret project codenamed Monte Carlo, named after the Monte Carlo Casino. One of the applications at that time was to simulate the chain reaction of a hydrogen bomb. The scientists wanted to handle the process in a controlled manner, since modeling such a stochastic process exactly is too complex.
Monte Carlo Estimation
Let’s see how to compute an expectation with sampling. The expectation E_{x∼p}[f(x)] = ∫ f(x) p(x) dx can be estimated by drawing N i.i.d. data points xᵢ from p and averaging: E_{x∼p}[f(x)] ≈ (1/N) Σᵢ f(xᵢ).
This is simple. The estimated expectation is unbiased, with variance inversely proportional to N. According to the Law of Large Numbers, as N increases, the variance decreases and the estimate approaches the population value.
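A minimal sketch of this estimator, using f(x) = x² under a standard normal p (so the true expectation is 1) as a stand-in example:

```python
import numpy as np

rng = np.random.default_rng(0)

def mc_expectation(f, sampler, n):
    """Estimate E_p[f(x)] as the average of f over n i.i.d. draws from p."""
    x = sampler(n)
    return f(x).mean()

# Example: f(x) = x^2 under p = N(0, 1); the true expectation is 1.
for n in (100, 10_000, 1_000_000):
    est = mc_expectation(lambda x: x**2, lambda size: rng.standard_normal(size), n)
    print(n, round(est, 4))   # the estimate tightens as n grows
```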
Generative model
Many probabilistic models, like the Bayesian network, are generative models. In short, we can draw samples from the model to make inferences.
In forward sampling, we start with nodes that have no dependencies and work down the tree, sampling data using the conditional probability at each node. After many iterations, we can use the collected samples to answer queries.
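A toy sketch of forward sampling on a made-up two-node network Rain → WetGrass (the structure and probabilities are invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny hypothetical Bayes net: Rain -> WetGrass.
p_rain = 0.2                              # p(Rain = 1)
p_wet_given_rain = {1: 0.9, 0: 0.1}       # p(WetGrass = 1 | Rain)

def forward_sample():
    """Sample parents first, then children from their conditional tables."""
    rain = rng.random() < p_rain
    wet = rng.random() < p_wet_given_rain[int(rain)]
    return int(rain), int(wet)

samples = [forward_sample() for _ in range(50_000)]

# Answer a query from the collected samples, e.g. p(Rain = 1 | WetGrass = 1).
rain_when_wet = [r for r, w in samples if w == 1]
print(round(sum(rain_when_wet) / len(rain_when_wet), 3))   # ~0.69 by Bayes' rule
```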
Unnormalized distribution
Let’s introduce some simple terminology before moving on. In some models, like graphical models, we compute factors and then normalize the result with a normalization constant Z to turn it into a probability distribution p. For the Bayesian network, the product of factors is already a distribution and Z = 1. For a Markov Random Field, we sum the product of factors over all possible x to compute Z.
We add a hat symbol to p(x), writing p̂(x), to indicate an unnormalized distribution, so that p(x) = p̂(x)/Z. Since Z sums the factors over all variables, it is hard to compute. The unnormalized distribution is often preferred over the normalized one when designing an ML algorithm since we don’t need to compute Z.
Rejection sampling
The concept of rejection sampling is simple. Let’s first explain it for the case where p(x) is easy to calculate but sampling from p is hard. We design an easy-to-sample distribution q. We draw a sample from it, say xᵢ. We accept xᵢ as a sample with probability p(xᵢ)/q(xᵢ). Intuitively, the accepted samples will be proportional to p. Indeed, as proven later, the accepted samples will have a distribution matching p.
Next, we will go through the details for the general scenario in which it is hard to sample data from the unnormalized or the normalized p.
To solve the problem, we choose a surrogate distribution q that is easy to sample from. However, we need three more conditions to make it work. First, q must generate samples wherever p does: mathematically, if p(x) > 0, then q(x) > 0 as well. Second, we choose the smallest k such that k·q(x) is greater than the unnormalized p̂(x) for all x. This is not easy, in particular in high-dimensional spaces, but it can be relaxed to the smallest such k that we can find, as long as k·q(x) remains an upper bound. That does not hurt correctness; it just makes the method less efficient. Third, we accept a sample x drawn from q with probability p̂(x)/(k·q(x)).
Intuitively, we accept data more often where the unnormalized p̂ is close to the upper bound k·q. Say the accepted data have a distribution s. The equations below prove that the accepted samples follow the distribution p. (As long as k·q(x) is an upper bound of the unnormalized p̂(x), the acceptance probability will never be clipped in rejection sampling.)
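A minimal sketch under these assumptions, using an invented two-mode p̂ and a wide Gaussian q; k is picked loosely by hand, which only costs efficiency:

```python
import numpy as np

rng = np.random.default_rng(0)

# Unnormalized target p_hat: a two-mode bump, easy to evaluate but awkward to sample.
def p_hat(x):
    return np.exp(-0.5 * (x - 2) ** 2) + 0.5 * np.exp(-0.5 * (x + 2) ** 2)

# Proposal q: a wide Gaussian N(0, 3^2) that is easy to sample and covers p_hat.
def q_pdf(x, scale=3.0):
    return np.exp(-0.5 * (x / scale) ** 2) / (scale * np.sqrt(2 * np.pi))

k = 12.0   # large enough that k * q(x) >= p_hat(x) for all x; a loose k only wastes samples

def rejection_sample(n):
    accepted = []
    while len(accepted) < n:
        x = rng.normal(0.0, 3.0)
        if rng.random() < p_hat(x) / (k * q_pdf(x)):   # accept with prob p_hat / (k q)
            accepted.append(x)
    return np.array(accepted)

xs = rejection_sample(5_000)
print(round(xs.mean(), 2))   # close to the mean of the normalized target (~0.67)
```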
But what is the sampling efficiency? How much data is rejected during sampling? It turns out the efficiency is extremely low in high-dimensional spaces, even when k·q(x) hugs the unnormalized p̂(x) closely. Most samples are thrown away, which is why this method is not popular. But it opens the way for other algorithms.
Importance sampling
In rejection sampling, we find k such that k·q(x) is an upper bound for the unnormalized p̂(x); our objective is only to find a nice upper bound. But what if we take advantage of being able to calculate the normalized p(x)? What will it bring? The expectation can be rewritten as E_{x∼p}[f(x)] = ∫ f(x) (p(x)/q(x)) q(x) dx = E_{x∼q}[f(x) p(x)/q(x)].
It evaluates f(x) with x sampled from the distribution q instead of p. Intuitively, the equation just reweights the expected value to reflect the change in sampling frequency between p and q.
This is called unnormalized importance sampling. If q is easier to sample from, we win, as long as we know how to compute the corresponding p(x). If f(x) has already been calculated in previous iterations under the distribution q, we win again, which happens in some reinforcement learning methods.
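A minimal sketch of the reweighting, with a made-up Gaussian target p, an easier Gaussian proposal q, and f(x) = x²:

```python
import numpy as np

rng = np.random.default_rng(0)

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Target p = N(1, 1); proposal q = N(0, 2) is assumed easier to sample from.
def p(x): return normal_pdf(x, 1.0, 1.0)
def q(x): return normal_pdf(x, 0.0, 2.0)

n = 100_000
x = rng.normal(0.0, 2.0, size=n)     # x ~ q
w = p(x) / q(x)                      # importance weights p(x)/q(x)
f = x ** 2                           # quantity of interest
print(round(np.mean(w * f), 3))      # E_p[x^2] = mu^2 + sigma^2 = 2
```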
To get the lowest variance for the expectation estimate, we choose q*(x) ∝ |f(x)| p(x). That is, we want q* to reflect the regions of x that have both high probability and high |f(x)| value (proof). The low variance allows us to draw fewer samples.
In a Bayesian model, p is easy to compute. But how can we extend this to other models, including MRFs, for which only the unnormalized distribution is easy to compute? Let’s reformulate the equation based on the unnormalized p̂ instead.
We compute the weights rᵐ = p̂(xᵐ)/q(xᵐ) with the unnormalized distribution and then use the total sum of the rᵐ to normalize them. This avoids computing the usual normalization constant Z. This is called normalized importance sampling.
But this approach comes with a price. Unnormalized importance sampling is unbiased, while normalized importance sampling is biased. However, normalized importance sampling is likely to have lower variance.
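A minimal sketch of the normalized version, where only an invented unnormalized p̂ is evaluated and the weights rᵐ are normalized by their sum instead of by Z:

```python
import numpy as np

rng = np.random.default_rng(0)

# Unnormalized target: 5 * N(1, 1) up to a constant; Z is never computed.
def p_hat(x):
    return 5.0 * np.exp(-0.5 * (x - 1.0) ** 2)

def q_pdf(x, mu=0.0, sigma=2.0):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

n = 100_000
x = rng.normal(0.0, 2.0, size=n)      # x ~ q
r = p_hat(x) / q_pdf(x)               # unnormalized weights r_m
w = r / r.sum()                       # normalize by the total weight instead of Z
print(round(np.sum(w * x ** 2), 3))   # E_p[x^2] = 2, without ever knowing Z
```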
Likelihood weighting
The sampling strategies so far can address any query we may want. Sometimes, this is like preparing a full-course dinner when we just want a small bite. What if we only want to query p(x|e) for some observed e, for example p(Gender=male | zip code=94232)? Should we generate all possible samples, or just those matching e (with zip code equal to 94232)? If there are many possible values for e, our sampled data will contain many irrelevant sample points and our sampling efficiency will be very low.
Let’s repeat the exercise for observations with only 2 years of experience and a zip code equal to 94222. We can use the highlighted data below to compute p(Gender | years of exp. = 2, zip code = 94222). The result shows that 94222 is mostly composed of the female population. (Yes, I fabricated this result to demonstrate an important point later.)
Now, let’s formulate a strategy for generating samples using forward sampling. Whenever we hit a node with a known observation, we force the variable to the observed value instead of sampling it from the node’s conditional probability. Assume g is observed with value g³ below. For d, we sample values from the conditional probability as usual, but for g, we force it to g³.
Below is one of the possible samples that we may generate. The zip code and the years of experience match the observations exactly.
However, there is a serious issue. The generated data are half male and half female, reflecting the general population under forward sampling. These values were not generated conditioned on zip code 94222, where 6 out of 7 should be female. To correct the problem, we assign a weight to each sampled data point: the product of the conditional probabilities of each observed value (zip code and experience) given the values of its parents in that sample. In our example, the observed events are the zip code and the years of experience. The weight will be w = ∏ᵢ p(eᵢ | Parents(eᵢ)).
For example, if zip code depends on gender, and experience depends on years of education and age, w will be calculated as w = p(zip code = 94222 | gender) × p(exp. = 2 | education, age), evaluated at the values appearing in that sample.
Finally, the query p(Gender | e) is computed as a weighted proportion: the sum of the weights for rows matching the query divided by the sum of the weights for all rows.
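A minimal sketch on a made-up two-node network Gender → Zip with invented probabilities: the evidence node is clamped to its observed value and each sample is weighted by the likelihood of that evidence given its parent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical network: Gender -> Zip. We observe Zip = "94222".
p_female = 0.5
p_zip_given_gender = {          # p(Zip = "94222" | Gender), invented numbers
    "female": 0.3,
    "male": 0.05,
}

def weighted_sample():
    """Forward-sample the free variable, clamp the evidence, return (sample, weight)."""
    gender = "female" if rng.random() < p_female else "male"
    weight = p_zip_given_gender[gender]      # likelihood of the clamped evidence
    return gender, weight

samples = [weighted_sample() for _ in range(100_000)]
num = sum(w for g, w in samples if g == "female")
den = sum(w for _, w in samples)
print(round(num / den, 3))   # p(Gender=female | Zip=94222) = 0.3 / 0.35 ~ 6 out of 7
```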
Sampling importance resampling (SIR)
Recall that we want q to be as close to p as possible. The q below samples frequently in areas that are not significant for p. Even if the estimate has low variance, it will not be accurate.
Since designing q can be tricky, can we design an algorithm that samples adaptively? Let’s start with some simple distribution q, such as a uniform distribution; for our illustration, we will use a Gaussian. The green dots are the samples generated by q.
Next, we calculate the weight based on the importance sampling ratio p(x)/q(x). So the green dots within p will have heavier weights.
Next, instead of using q to generate samples, we use the weights to determine the next sampling distribution. Hence, the sampled data will now shift towards p.
For future iterations, we continue to use the weights to generate samples for the next round. Conceptually, we can visualize the green dots as particles. We re-weight each particle’s size based on the importance sampling weight ratio, and in the next iteration each particle spawns new particles in proportion to its weighted size. For more information, you can read about the particle filter here, which is used in self-driving car technology.
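A minimal sketch of one SIR round, with an invented unnormalized target: draw particles from an initial q, weight them by p̂/q, then resample in proportion to the weights so the particle cloud drifts toward p.

```python
import numpy as np

rng = np.random.default_rng(0)

def p_hat(x):
    """Unnormalized target, easy to evaluate pointwise (a Gaussian bump at 4)."""
    return np.exp(-0.5 * (x - 4.0) ** 2)

n = 5_000
x = rng.normal(0.0, 3.0, size=n)                       # particles from an initial guess q
q_pdf = np.exp(-0.5 * (x / 3.0) ** 2) / (3.0 * np.sqrt(2 * np.pi))
w = p_hat(x) / q_pdf                                   # importance weights
w /= w.sum()

resampled = rng.choice(x, size=n, replace=True, p=w)   # resample in proportion to weight
print(round(x.mean(), 2), round(resampled.mean(), 2))  # ~0 before, ~4 after (target mean)
```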
Challenge for Sampling Inference
Many sampling-based inference methods assume p is hard to sample from, so we use a tractable distribution q instead. So far, we have modeled q to be as close to p as possible; for example, if p is Gaussian-like, q will be a Gaussian. In most ML problems, a grand master plan for such global structure is missing, and humans are not good at high-dimensional thinking or visualization. So we do not understand p globally, and therefore we cannot derive q.
Maybe, for some ML problems, we should use a bottom-up approach rather than a top-down approach. Indeed, many ML problems try to discover local knowledge to gain global insight. As mentioned before, we often know how to calculate p(x), the probability density at each point. We may start solving the problem with this local information. In fact, SIR in the last section is one good example.
To demonstrate this bottom-up approach, let’s drop our out-of-town shoppers off randomly in SF. At each street corner, they look ahead to see which direction seems more interesting. After a while, they should all end up in the shopping districts, a simple but effective way to determine where the most popular shops are.
However, instead of using a deterministic policy at each corner, we use a probability model. This gives us a chance to explore even when a direction looks less promising. In particular, this scheme allows us to move from one high-probability area to another by crossing a low-probability region; we just won’t spend too much time there.
Next, we will describe how to generate samples adaptively based on the last sampled data point.
Markov Chain Process
In a Markov chain process, each node below is a state; for example, we are in the Apple store. The arrows represent the chance of transitioning to the next state in the next time step; for example, there is a 0.5 chance that we will stay and a 0.2 chance that we will go to Zara.
Such transitions can be represented by a transition matrix A, with the state distribution at the next time step computed as uₖ₊₁ = A uₖ.
We can start from an initial state, say we are in the Apple store, i.e. u₀ = [1, 0, 0, …]. If some conditions are met, uₖ will converge to a stationary distribution π, and π(xᵢ) will be the probability of being at state xᵢ in the stationary distribution.
i.e., the values in the state vector u will no longer change.
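A minimal sketch with a made-up 3-state transition matrix (column j holds p(next state | current state = j)): repeatedly applying A to any starting distribution converges to the same π.

```python
import numpy as np

# Hypothetical transition matrix; each column sums to 1.
A = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.4, 0.3],
              [0.3, 0.3, 0.5]])

u = np.array([1.0, 0.0, 0.0])       # start in the first store with certainty
for _ in range(100):
    u = A @ u                       # u_{k+1} = A u_k
print(np.round(u, 3))               # stationary distribution pi

v = np.array([0.0, 0.0, 1.0])       # a different starting point...
for _ in range(100):
    v = A @ v
print(np.round(v, 3))               # ...converges to the same pi
```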
Actually, the stationary distribution does not depend on the initial state. Detailed mathematical reasoning can be found here, but it may be easier to understand from a spouse’s perspective: no matter where you drop off your spouse for shopping, you should have a reasonable estimate of his or her possible whereabouts later. Don’t underestimate the power of the Markov chain process beyond finding your spouse. Google started its web page ranking algorithm (PageRank) by solving exactly this problem, just with a pretty large matrix A to cover the pages on the Internet.
So when may a stationary distribution exist? Two sufficient conditions are:
- Irreducibility: you can get from any state to any other state. The diagram below is not irreducible: we cannot reach B from A or vice versa, and if you start from C, you can end up with two very different state distributions. Irreducibility forbids being trapped in a subset of states forever.
- Aperiodicity: starting from any state, you can return to that state at any time step after some warm-up. The following chain is not aperiodic: if we start from A, we can return to A only at even time steps. Technically, aperiodicity means every state has a period of one. Otherwise, the state distribution at the next time step will not equal the current distribution.
Another sufficient condition for the stationary distribution, called detailed balance, is π(xᵢ)·A(xᵢ → xⱼ) = π(xⱼ)·A(xⱼ → xᵢ) for all pairs of states.
If we can find π to satisfy the condition above, π will be the stationary distribution as shown below.
For Google PageRank, A contains the transition probabilities from one web page to another across the Internet, weighted by the number of hyperlinks that link to a page. Google solved the problem by repeatedly multiplying this huge sparse matrix A until u converges (at least that was how it was done when the company was founded). How to multiply two large sparse matrices is one coding interview question I heard at Google. While it is a nice scalability challenge, it can also be solved through sampling: we use the transition probabilities to sample the next state, and after some warm-up, the collected states represent the distribution π well. But once again, this assumes p is easy to sample from. Let’s expand the idea without this assumption.
Markov Chain Monte Carlo
In Markov Chain Monte Carlo (MCMC), we adopt the Markov chain concept of sampling the next state (a.k.a. the next sample) based on the current state. But we generalize the concept by introducing a tractable proposal distribution g to determine which state to sample next, so we do not require p to be easy to sample from. We then decide whether to accept the new sample using an acceptance test related to p; if it is rejected, we sample from the last sample point again. The acceptance test is optional: if g is closely related to p, it may not be needed. MCMC is a general concept that does not dictate the proposal distribution or the optional acceptance test, so to understand it, it is better to detail some of its specific methods.
Metropolis-Hastings algorithm
The Metropolis-Hastings algorithm is one of the most popular MCMC methods. First, we pick a proposal distribution g that determines which state to explore next. One possibility is a normal distribution centered at the current state, i.e. we pick a neighbor with chances decreasing exponentially as it moves away from the current state. It is simple and takes advantage of the possibility that, after some warm-up, we may already be in a high-probability-density region of p. Other choices can be used, but we will stick with the normal distribution for our demonstration.
Next, we establish an acceptance probability A related to p. It determines whether we select the candidate state as a sample or pick another neighbor instead. Here, we accept the new sample x’ with probability A = min(1, p(x’) g(x|x’) / (p(x) g(x’|x))).
p(x’)/g(x’|x) is like the importance weight for x’ in SIR, while p(x)/g(x|x’) is like the importance weight for x. So A estimates how likely we should be to include x’ when sampling p. The normal distribution is symmetric, i.e. g(x|x’) = g(x’|x), so the acceptance probability A drops below 1 when the new state x’ is less likely than the current state x according to p. If g is not symmetric, we give a better chance to transitions that can be undone more easily, i.e. when g(x’|x) < g(x|x’).
Let’s check whether we can use a normal distribution as a proposal to achieve a stationary distribution. Recall that the sufficient conditions are that we can reach any state from any other and that the chain has a period of one. With a normal proposal, we can always reach any state from any other with non-zero probability. For the second condition, if the current state is xᵢ, it is possible to transition from xᵢ back to xᵢ in the next time step, which holds for a normal proposal. As shown below, as the number of sample points increases, the distribution of the samples resembles the distribution p, so we can use the collected samples to calculate expectations or to approximate p.
In our example, we use a symmetric proposal distribution, and this special case is called the Metropolis algorithm (instead of the Metropolis-Hastings algorithm). Intuitively, proposing xᵢ from xⱼ has the same chance as proposing in the opposite direction. The acceptance probability A simplifies to A = min(1, p(x’)/p(x)).
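A minimal sketch of the Metropolis algorithm with a symmetric Gaussian proposal, targeting an invented two-mode unnormalized p̂:

```python
import numpy as np

rng = np.random.default_rng(0)

def p_hat(x):
    """Unnormalized target; only pointwise evaluation is needed."""
    return np.exp(-0.5 * (x - 2) ** 2) + 0.5 * np.exp(-0.5 * (x + 2) ** 2)

def metropolis(n, step=1.0, x0=0.0):
    x, samples = x0, []
    for _ in range(n):
        x_new = x + rng.normal(0.0, step)         # symmetric proposal g(x'|x)
        a = min(1.0, p_hat(x_new) / p_hat(x))     # acceptance probability
        if rng.random() < a:
            x = x_new                             # accept; otherwise keep the old state
        samples.append(x)
    return np.array(samples)

xs = metropolis(50_000)[5_000:]    # drop warm-up samples
print(round(xs.mean(), 2))         # close to the mean of the normalized target (~0.67)
```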
Gibbs sampling
Designing a proposal distribution can be tricky. Gibbs sampling is another MCMC method that needs no specially designed proposal. Let’s assume each data point is composed of D variables, and that it is hard to sample from the joint probability p(x₁, x₂, …, x_D),
but it is not hard if we fix all but one variable, i.e. to sample from p(xᵢ | x₁, …, xᵢ₋₁, xᵢ₊₁, …, x_D).
That is, the conditional distributions above are tractable and easy to sample from. Then, instead of sampling all D values of a data point at once, we sample just one component. For the next data point, we sample another component, so consecutive data points differ in one component only.
Hence, we don’t need to create a proposal distribution independent of p. The algorithm sweeps through the components, resampling each one from its conditional distribution given the current values of all the others.
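A minimal sketch of such a sampler for a 2-D Gaussian with correlation ρ, a made-up target whose conditionals are simple 1-D Gaussians:

```python
import numpy as np

rng = np.random.default_rng(0)

def gibbs_bivariate_normal(n, rho=0.8):
    """Sample a standard bivariate normal with correlation rho by Gibbs sampling.

    Each conditional p(x_i | x_j) is N(rho * x_j, 1 - rho^2), so we can resample
    one coordinate at a time while holding the other fixed."""
    x1, x2 = 0.0, 0.0
    sd = np.sqrt(1.0 - rho ** 2)
    samples = np.empty((n, 2))
    for t in range(n):
        if t % 2 == 0:
            x1 = rng.normal(rho * x2, sd)   # update only x1 this step
        else:
            x2 = rng.normal(rho * x1, sd)   # update only x2 this step
        samples[t] = (x1, x2)               # consecutive samples differ in one component
    return samples

s = gibbs_bivariate_normal(50_000)[5_000:]              # drop warm-up
print(round(np.corrcoef(s[:, 0], s[:, 1])[0, 1], 2))    # ~0.8, matching the target
```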
There are other slightly different implementations of the same idea.
The Gibbs sampler is a special case of the Metropolis-Hastings algorithm. We don’t need the acceptance test because the acceptance probability simplifies to one.
In the left diagram below, we plot the sampled points for the first 50 iterations. This example is in a 2-dimensional space (D=2). In each iteration, we change one dimension only, so the chain moves only vertically or horizontally in alternating steps. The example draws 1000 points from the original data distribution p and 1000 points from Gibbs sampling; as expected, both plots look similar.
Shortcoming
Sampling-based methods suffer from a few issues. Since they are not optimization methods, there is no objective function, so the key disadvantage is that we have no measurement of progress or of how far we are from the solution. In addition, many sampling methods require a proposal distribution, which may not be easy to design.
Next
We have now covered sampling-based inference. Variational inference is another major approximation method. More information can be found in this article.
For those interested in AlphaGo Zero: