Reference:
https://mbernste.github.io/posts/elbo/
https://mbernste.github.io/posts/variational_inference/
https://mbernste.github.io/posts/em/
https://mbernste.github.io/posts/gmm_em/
Theodoridis S. Machine Learning: A Bayesian and Optimization Perspective. Academic Press, 2015.
Content
Latent variable model
- We posit that our observed data $x$ is a realization of some random variable $X$.
- We posit the existence of another random variable $Z$, where $X$ and $Z$ are distributed according to a joint distribution $p(X, Z; \theta)$, where $\theta$ parameterizes the distribution.
- Our data is only a realization of $X$, not $Z$, and therefore $Z$ remains unobserved (i.e., latent). A concrete example is sketched below.
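For concreteness, here is a minimal generative sketch of such a model: a hypothetical two-component Gaussian mixture (the same family treated in the GMM reference above), where the component assignment plays the role of $Z$ and only $x$ is observed. The specific parameter values are illustrative assumptions.

```python
import numpy as np

# Hypothetical two-component Gaussian mixture as a latent variable model:
#   Z ~ Categorical(pi),  X | Z = k ~ Normal(mu[k], sigma[k])
# theta = (pi, mu, sigma); z is latent, only x is observed.
rng = np.random.default_rng(0)
pi = np.array([0.3, 0.7])       # mixture weights
mu = np.array([-2.0, 3.0])      # component means
sigma = np.array([0.5, 1.0])    # component standard deviations

z = rng.choice(2, size=1000, p=pi)   # latent assignments (we never get to see these)
x = rng.normal(mu[z], sigma[z])      # observed data
```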
There are two predominant tasks that we may be interested in accomplishing:
- Given some fixed value for $\theta$, compute the posterior distribution $p(Z \mid X; \theta)$ [Can be solved by variational inference]
- Given that $\theta$ is unknown, find the maximum likelihood estimate of $\theta$ [Can be solved by EM]
Both variational inference and EM rely on the ELBO.
The Evidence Lower Bound (ELBO)
What is the ELBO?
To understand the evidence lower bound, we must first understand what we mean by "evidence": it is just a name given to the likelihood function evaluated at a fixed $\theta$:

$$\text{evidence} := \log p(x;\theta) \tag{ELBO.1}$$
Why is this quantity called the “evidence”?
Intuitively, if we have chosen the right model $p$ and parameters $\theta$, then we would expect the marginal probability of our observed data $x$ to be high. Thus, a higher value of $\log p(x;\theta)$ indicates, in some sense, that we may be on the right track with the model $p$ and parameters $\theta$ that we have chosen. That is, this quantity is "evidence" that we have chosen the right model for the data.
If we happen to know (or posit) that $Z$ follows some distribution denoted by $q$, s.t.

$$p(x,z;\theta) := p(x \mid z;\theta)\, q(z) \tag{ELBO.2}$$
Then the evidence lower bound is just a lower bound on the evidence that makes use of the known $q$. Specifically,

$$\log p(x;\theta) \ge E_{Z\sim q}\left[\log \frac{p(x,Z;\theta)}{q(Z)}\right] \tag{ELBO.3}$$
where the ELBO is simply the right-hand side of the above equation:
$$\text{ELBO} := E_{Z\sim q}\left[\log \frac{p(x,Z;\theta)}{q(Z)}\right] \tag{ELBO.4}$$
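Because (ELBO.4) is an expectation under $q$, it can be estimated by Monte Carlo whenever we can sample from $q$ and evaluate $\log p(x, z;\theta)$ and $\log q(z)$. Below is a minimal sketch, assuming (purely for illustration) a Gaussian prior on $z$, a Gaussian likelihood for $x$, and a Gaussian $q$; in this conjugate case the evidence is also available in closed form, so the bound can be checked directly.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = 1.5  # a single observed data point

# Illustrative model (theta fixed): z ~ N(0, 1),  x | z ~ N(z, 1)
log_p_xz = lambda z: norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)

# Illustrative variational distribution q(z) = N(m, s^2)
m, s = 0.5, 0.8
z = rng.normal(m, s, size=100_000)                   # Z ~ q
elbo = np.mean(log_p_xz(z) - norm.logpdf(z, m, s))   # E_q[log p(x,Z;theta) - log q(Z)]

# Here the marginal is x ~ N(0, 2), so the evidence is known exactly.
evidence = norm.logpdf(x, 0.0, np.sqrt(2.0))
print(elbo, evidence)   # the ELBO estimate sits below the evidence
```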
Derivation
Jensen's inequality: if $X$ is a random variable and $\varphi$ is a convex (concave) function, then

$$\varphi(E[X]) \le (\ge)\; E[\varphi(X)]$$

Since $\log(\cdot)$ is a concave function,

$$\begin{aligned} E_{Z\sim q}\left[\log \frac{p(x,Z;\theta)}{q(Z)}\right] &\le \log\left[E_{Z\sim q}\left(\frac{p(x,Z;\theta)}{q(Z)}\right)\right] \\ &= \log\left[\int q(z)\,\frac{p(x,z;\theta)}{q(z)}\,dz\right] \\ &= \log p(x;\theta) \end{aligned}$$
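The only inequality in the chain above is Jensen's; the remaining steps just rewrite the expectation as an integral and use $\int p(x,z;\theta)\,dz = p(x;\theta)$. As a quick numerical illustration of Jensen's inequality for the concave logarithm (the choice of distribution here is an arbitrary assumption for the demo):

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.lognormal(mean=0.0, sigma=1.0, size=1_000_000)  # any positive random variable

print(np.log(r.mean()))   # log E[R]  (about 0.5 for this distribution)
print(np.log(r).mean())   # E[log R] (about 0.0) -- smaller, as Jensen's inequality predicts
```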
The gap between the evidence and the ELBO
It turns out that the gap between the evidence and the ELBO is precisely the Kullback-Leibler divergence between $q(Z)$ and $p(Z \mid x;\theta)$:

$$\text{evidence} - \text{ELBO} := \log p(x;\theta) - E_{Z\sim q}\left[\log \frac{p(x,Z;\theta)}{q(Z)}\right] = KL(q(Z)\,\|\,p(Z \mid x;\theta)) \tag{ELBO.5}$$
This fact forms the basis of the variational inference algorithm for approximate Bayesian inference.
Derivation
$$\begin{aligned} \log p(x;\theta) - E_{Z\sim q}\left[\log \frac{p(x,Z;\theta)}{q(Z)}\right] &= \int q(z)\log p(x;\theta)\,dz - \int q(z)\log \frac{p(x,z;\theta)}{q(z)}\,dz \\ &= \int q(z)\,\log \frac{p(x;\theta)\, q(z)}{p(x,z;\theta)}\,dz \\ &= \int q(z)\,\log \frac{q(z)}{p(z \mid x;\theta)}\,dz \\ &= KL(q(Z)\,\|\,p(Z \mid x;\theta)) \end{aligned}$$
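The step from the second to the third line uses $p(x,z;\theta) = p(z \mid x;\theta)\,p(x;\theta)$. The identity (ELBO.5) is easy to verify numerically on a model with a discrete latent variable, where every quantity can be computed exactly; the small joint table below is an illustrative assumption.

```python
import numpy as np

# Illustrative joint p(x, z) over a discrete z in {0, 1, 2}, for one fixed observation x
p_x_and_z = np.array([0.10, 0.25, 0.05])    # p(x, z) for each value of z
p_x = p_x_and_z.sum()                       # evidence p(x)
p_z_given_x = p_x_and_z / p_x               # posterior p(z | x)

q = np.array([0.5, 0.3, 0.2])               # an arbitrary distribution q(z)

evidence = np.log(p_x)
elbo = np.sum(q * np.log(p_x_and_z / q))    # E_q[log p(x, Z) / q(Z)]
kl = np.sum(q * np.log(q / p_z_given_x))    # KL(q || p(Z | x))

print(evidence - elbo, kl)                  # identical up to floating-point error
```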
Variational Inference
Why variational inference?
Variational inference is a paradigm for estimating a posterior distribution when computing it explicitly is intractable.
Assume that we have a model that involves hidden random variables $Z$, observed random variables $X$, and some posited probabilistic model over the hidden and observed random variables $P(X, Z)$. The goal is to compute the posterior distribution $P(Z \mid X)$.
Ideally, we would do so by using Bayes' theorem:

$$p(z \mid x) = \frac{p(x \mid z)\,p(z)}{p(x)}$$
In practice, it is often difficult to compute $p(z \mid x)$ via Bayes' theorem because the denominator $p(x)$ does not have a closed form. Usually, the denominator $p(x)$ can only be expressed as an integral that marginalizes over $z$: $p(x) = \int p(x,z)\,dz$. In such scenarios, we are often forced to approximate $p(z \mid x)$ rather than compute it directly. Variational inference is one such approximation technique.
Details
Instead of computing $p(z \mid x)$ exactly via Bayes' theorem, variational inference attempts to find another distribution $q(z)$ that is 'close' to $p(z \mid x)$, where 'closeness' is measured by the KL-divergence:
$$KL(q(Z)\,\|\,p(Z \mid x)) = \int q(z)\,\log \frac{q(z)}{p(z \mid x)}\,dz = E_{Z\sim q}\left[\log \frac{q(Z)}{p(Z \mid x)}\right] \tag{VI.1}$$
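The KL in (VI.1) cannot be evaluated directly because it involves the intractable posterior, but by (ELBO.5) minimizing it over $q$ is equivalent to maximizing the ELBO, since the evidence does not depend on $q$. Here is a minimal sketch of this idea, reusing the illustrative Gaussian example from the ELBO section ($z \sim N(0,1)$, $x \mid z \sim N(z,1)$, $q = N(m, s^2)$) with a generic optimizer; the model and the use of scipy.optimize are assumptions for the demo, not part of the general algorithm.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

rng = np.random.default_rng(0)
x = 1.5
eps = rng.normal(size=50_000)   # fixed base noise; reparameterize z = m + s * eps

def neg_elbo(params):
    m, log_s = params
    s = np.exp(log_s)
    z = m + s * eps
    log_p_xz = norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)
    log_q = norm.logpdf(z, m, s)
    return -np.mean(log_p_xz - log_q)   # maximizing the ELBO = minimizing KL(q || p(z|x))

res = minimize(neg_elbo, x0=np.array([0.0, 0.0]), method="Nelder-Mead")
m_opt, s_opt = res.x[0], np.exp(res.x[1])
print(m_opt, s_opt)   # approaches the exact posterior N(0.75, 0.5), i.e. s close to 0.71
```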