I had the hardest time trying to understand variational inference. All of the presentations I’ve seen (MacKay, Bishop, Wikipedia, Gelman’s draft for the third edition of Bayesian Data Analysis) are deeply tied up with the details of a particular model being fit. I wanted to see the algorithm and get the big picture before being overwhelmed with multivariate exponential family gymnastics.
Bayesian Posterior Inference
In the Bayesian setting (see my earlier post, What is Bayesian Inference?), we have a joint probability model $p(y, \theta)$ for data $y$ and parameters $\theta$, usually factored as the product of a likelihood and prior term, $p(y, \theta) = p(y|\theta) \, p(\theta)$. Given some observed data $y$, Bayesian predictive inference is based on the posterior density $p(\theta|y)$ of the unknown parameter vector $\theta$ given the observed data vector $y$. Thus we need to be able to estimate the posterior density to carry out Bayesian inference. Note that the posterior is a whole density function; we're not just after a point estimate as in maximum likelihood estimation.
Mean-Field Approximation
Variational inference approximates the Bayesian posterior density $p(\theta|y)$ with a (simpler) density $q(\theta|\phi)$ parameterized by some new parameters $\phi$. The mean-field form of variational inference factors the approximating density by component of $\theta = (\theta_1, \ldots, \theta_J)$, as

$q(\theta|\phi) = \prod_{j=1}^J q_j(\theta_j|\phi_j)$.
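As a concrete (hypothetical) illustration of this factorization, here is a short Python sketch in which each factor $q_j$ is a univariate Gaussian with its own mean and scale; the names and the choice of Gaussian factors are mine, not from the texts above. The point is only that independence turns the joint log-density into a sum over components:

```python
import numpy as np

def log_q(theta, mu, sigma):
    """Mean-field log-density: log q(theta|phi) = sum_j log q_j(theta_j|mu_j, sigma_j),
    with each factor a univariate Gaussian (an illustrative choice)."""
    # Log-density of each independent Gaussian factor q_j.
    log_factors = (-0.5 * np.log(2.0 * np.pi * sigma ** 2)
                   - 0.5 * ((theta - mu) / sigma) ** 2)
    # Independence across components means the joint log-density is the sum.
    return log_factors.sum()

theta = np.array([0.5, -1.0, 2.0])   # a point in parameter space
mu = np.array([0.0, -1.0, 1.5])      # variational means phi: (mu_j, sigma_j)
sigma = np.array([1.0, 0.5, 2.0])    # variational scales
print(log_q(theta, mu, sigma))
```

Each pair $(\mu_j, \sigma_j)$ plays the role of $\phi_j$; richer factor families work the same way, one factor per component.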
I’m going to put off actually defining the terms until we see how they’re used in the variational inference algorithm.
What Variational Inference Does
The variational inference algorithm finds the value $\phi^*$ for the parameters $\phi$ of the approximation which minimizes the Kullback-Leibler divergence of $p(\theta|y)$ from $q(\theta|\phi)$,

$\phi^* = \arg\min_{\phi} \, \mathrm{KL}[\, q(\theta|\phi) \,\|\, p(\theta|y) \,]$.
The key idea here is that variational inference reduces posterior estimation to an optimization problem. Optimization is typically much faster than approaches to posterior estimation such as Markov chain Monte Carlo (MCMC).
The main disadvantage of variational inference is that the posterior is only approximated (though as MacKay points out, just about any approximation is better than a delta function at a point estimate!). In particular, variational methods systematically underestimate posterior variance because of the direction of the KL divergence that is minimized. Expectation propagation (EP) also converts posterior fitting to optimization of KL divergence, but EP uses the opposite direction of KL divergence, which leads to overestimation of posterior variance.
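The variance underestimation can be seen in a small worked example (my example, not from any of the texts above): when the exact posterior is a correlated bivariate Gaussian with covariance $\Sigma$, the fully factorized Gaussian minimizing $\mathrm{KL}[q \| p]$ has factor variances $1/(\Sigma^{-1})_{jj}$ (the closed-form result in Bishop's PRML, section 10.1.2), which are never larger than the true marginal variances $\Sigma_{jj}$:

```python
import numpy as np

# True posterior: bivariate Gaussian with strong positive correlation.
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])        # posterior covariance
Lambda = np.linalg.inv(Sigma)         # posterior precision matrix

true_marginal_var = np.diag(Sigma)    # exact marginal variances: [1.0, 1.0]
# Optimal mean-field Gaussian under KL[q||p]: variance_j = 1 / (Sigma^{-1})_{jj}.
vi_var = 1.0 / np.diag(Lambda)        # here: [0.36, 0.36]

print("true marginal variances:", true_marginal_var)
print("variational variances:  ", vi_var)   # systematically smaller
```

The stronger the posterior correlation, the worse the shrinkage; with the correlation set to zero the two agree exactly.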
Variational Inference Algorithm
Given the Bayesian model $p(y, \theta)$, observed data $y$, and functional terms $q_j$ making up the approximation $q(\theta|\phi)$ of the posterior $p(\theta|y)$, the variational inference algorithm is:

- set $\phi$ to initial values
- until $\phi$ converges,
- for $j$ in $1{:}J$,
- set $\phi_j$ so that $q_j(\theta_j|\phi_j) \propto \exp\!\big( \mathbb{E}_{\theta_{-j} \sim q(\theta_{-j}|\phi_{-j})}[\, \log p(y, \theta) \,] \big)$.

The right-hand side is a function of $\theta_j$ returning a single non-negative value, with the inner expectation integrating the other components out,

$\mathbb{E}_{\theta_{-j} \sim q(\theta_{-j}|\phi_{-j})}[\, \log p(y, \theta) \,] = \int \log p(y, \theta) \; q(\theta_{-j}|\phi_{-j}) \, d\theta_{-j}$.
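To make the loop concrete, here is a sketch of the algorithm for a toy conjugate model (my choice of example, not from the post): $y_i \sim \mathrm{Normal}(\mu, \tau^{-1})$ with a conjugate Normal prior on $\mu$ and a Gamma prior on $\tau$, where the coordinate updates for $q(\mu)\,q(\tau)$ have the closed forms worked out in Bishop's PRML, section 10.1.3:

```python
import numpy as np

# Toy conjugate model: y_i ~ Normal(mu, 1/tau),
# mu ~ Normal(mu0, 1/(kappa0*tau)), tau ~ Gamma(a0, b0).
# Mean-field approximation q(mu, tau) = q(mu) q(tau) with
# q(mu) = Normal(mu_N, 1/lambda_N) and q(tau) = Gamma(a_N, b_N).
rng = np.random.default_rng(0)
y = rng.normal(loc=2.0, scale=1.0, size=50)   # simulated data
N, ybar = len(y), y.mean()

mu0, kappa0, a0, b0 = 0.0, 1.0, 1.0, 1.0      # prior hyperparameters

E_tau = 1.0                                   # initialize E_q[tau]
for _ in range(100):                          # until phi converges
    # Update q(mu), holding q(tau) fixed at its current expectation.
    mu_N = (kappa0 * mu0 + N * ybar) / (kappa0 + N)
    lambda_N = (kappa0 + N) * E_tau
    var_mu = 1.0 / lambda_N
    # Update q(tau), holding q(mu) fixed (expectations over mu in closed form).
    a_N = a0 + (N + 1) / 2.0
    b_N = b0 + 0.5 * (np.sum((y - mu_N) ** 2) + N * var_mu
                      + kappa0 * ((mu_N - mu0) ** 2 + var_mu))
    E_tau = a_N / b_N

print("approximate posterior mean of mu: ", mu_N)
print("approximate posterior mean of tau:", E_tau)
```

Each pass through the loop is one sweep of the coordinate updates above; here $\phi = (\mu_N, \lambda_N, a_N, b_N)$, and convergence is fast because both updates are exact given the other factor.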
Despite the suggestive factorization of $q(\theta|\phi)$ and the coordinate-wise nature of the algorithm, variational inference does not simply approximate the posterior marginals $p(\theta_j|y)$ independently.
Defining the Approximating Densities
The trick is to choose the approximating factors $q_j$ so that we can compute, in closed form, parameter values $\phi_j$ such that $q_j(\theta_j|\phi_j) \propto \exp\!\big( \mathbb{E}_{\theta_{-j} \sim q(\theta_{-j}|\phi_{-j})}[\, \log p(y, \theta) \,] \big)$. Finding such approximating terms given a posterior is an art form unto itself. It's much easier for models with conjugate priors. Bishop's and MacKay's books and Wikipedia present calculations for a wide range of exponential-family models.
What if My Model is not Conjugate?
Unfortunately, I almost never work with conjugate priors (and even if I did, I'm not much of a hand at exponential-family algebra). Therefore, the following paper just got bumped to the top of my must-understand queue:
- Wang, Chong and David M. Blei. 2012–2013. Variational Inference in Nonconjugate Models. arXiv:1209.4360.
It’s great having Dave down the hall on sabbatical this year — one couldn’t ask for a better stand-in for Matt Hoffman. They are both insanely good at on-the-fly explanations at the blackboard (I love that we still have real chalk and high-quality boards).