变分推断(Variational Inference)

从变分推断(Variational Inference)说起

   在贝叶斯体系中,推断(inference) 指的是 利用已知变量x的观测值推测未知变量z的后验分布,即我们在已经输入变量x后,如何获得未知变量z的分布p(z|x)[3].通俗一点讲一个完整的故事就是,如果没有任何信息,我们可能大概了解一个(latent)变量z的分布,这个分布可能方差比较大。变量x是可观察的,并含有z的一些信息。那么在观察到x后,关于z的分布(此时是后验分布p(z|x))会发生变化,比如方差变得更小了,如下图所示。
在这里插入图片描述
  利用贝叶斯公式:[4]
在这里插入图片描述
p ( x ∣ z ) p(x|z) p(xz) p ( z ) p(z) p(z)可以做出必要的假设符合某个分布。 p ( x ) p(x) p(x)是已经观察到的,所以称为证据(evidence)。
变分推断的一般步骤:
在这里插入图片描述
  精确推断方法准确地计算 p ( z ∣ x ) p(z|x) p(zx),该过程往往需要很大的计算开销,现实应用中近似推断更为常用。近似推断的方法往往分为两大类:

  • 第一类是采样,常见的是MCMC方法,
  • 第二类是使用另一个分布近似 p ( z ∣ x ) p(z|x) p(zx),典型代表就是变分推断。变分推断可以是推断后验分布的期望或者方差。

近似变分推断,就是要找到一个分布 q ∗ ( z ) q^*(z) q(z)去近似后验分布 p ( z ∣ x ) p(z|x) p(zx)

  • 指定一个关于z的分布族Q
  • 找到一个 q ∗ ( z ) ∈ Q q^*(z) \in Q q(z)Q去近似 p ( z ∣ x ) p(z|x) p(zx)
    在这里插入图片描述
    其中L是一种度量,可以度量两个分布分近似程度。Variational Bayes(变分贝叶斯,VB) 的这个度量采用KL距离:
    在这里插入图片描述
      KL距离,是Kullback-Leibler差异(Kullback-Leibler Divergence)的简称,也叫做相对熵(RelativeEntropy)。它衡量的是相同事件空间里的两个概率分布的差异情况。其物理意义是:在相同事件空间里,概率分布P(x)的事件空间,若用概率分布 Q ( x ) Q(x) Q(x)编码时,平均每个基本事件(符号)编码长度增加了多少比特。[5]
    在这里插入图片描述
    这里对KL的意义再重点讨论一下:
      KL的意义其实也很好理解。现在假如有两个概率分布P(x)和Q(x),现在要看看分布Q(x)与分布P(x)的接近程度。怎么做呢?其实很容易能够想到,就是根据分布P(x)中采样N个数: x 1 , x 2 , . . . , x N x_1,x_2,...,x_N x1,x2,...,xN,看 P ( x 1 ) P ( x 2 ) . . . P ( x N ) Q ( x 1 ) Q ( x 2 ) . . . Q ( x N ) \frac{P(x_1)P(x_2)...P(x_N)}{Q(x_1)Q(x_2)...Q(x_N)} Q(x1)Q(x2)...Q(xN)P(x1)P(x2)...P(xN)与1
    的接近程度,如果取对数就是 l o g ( P ( x 1 ) Q ( x 1 ) ) + l o g ( P ( x 2 ) Q ( x 2 ) ) + . . . + l o g ( P ( x N ) Q ( x N ) ) log(\frac{P(x_1)}{Q(x_1)})+log(\frac{P(x_2)}{Q(x_2)})+...+log(\frac{P(x_N)}{Q(x_N)}) log(Q(x1)P(x1))+log(Q(x2)P(x2))+...+log(Q(xN)P(xN))与0的接近程度,取平均数: 1 N ( l o g ( P ( x 1 ) Q ( x 1 ) ) + l o g ( P ( x 2 ) Q ( x 2 ) ) + . . . + l o g ( P ( x N ) Q ( x N ) ) ) \frac{1}{N}(log(\frac{P(x_1)}{Q(x_1)})+log(\frac{P(x_2)}{Q(x_2)})+...+log(\frac{P(x_N)}{Q(x_N)})) N1(log(Q(x1)P(x1))+log(Q(x2)P(x2))+...+log(Q(xN)P(xN))),这个就是对 ∑ x ∈ X P ( x ) l o g ( P ( x ) Q ( x ) ) \sum_{x \in X}P(x)log(\frac{P(x)}{Q(x)}) xXP(x)log(Q(x)P(x))的估计。因为是看分布Q(x)与分布P(x)的接近程度,所以是从P(x)取样,如果是看分布P(x)与分布Q(x)的接近程度,那就是从Q(x)取样,那么就是 ∑ x ∈ X Q ( x ) l o g ( Q ( x ) P ( x ) ) \sum_{x \in X}Q(x)log(\frac{Q(x)}{P(x)}) xXQ(x)log(P(x)Q(x)),二者肯定是不一样的,所以KL距离不是对称的。

   这样做下面的一步推导,其中 p ( z ∣ x ) p(z|x) p(zx)是未知的(本质上, p ( x ) p(x) p(x)是未知的)。
在这里插入图片描述
这样做进一步的变化,把 p ( z ∣ x ) p(z|x) p(zx)拆开。
在这里插入图片描述
关于ELBO(evidence lower bound,evidence下界):

  • ELBO是evidence的对数的下界
  • K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) KL(q(z)||p(z|x)) KL(q(z)∣∣p(zx))的最小化等价于对ELBO的最大化(做这样的转化是必要的,因为ELBO中是没有未知量的)
  • ELBO中的 p ( x , z ) = p ( x ∣ z ) p ( z ) p(x,z)=p(x|z)p(z) p(x,z)=p(xz)p(z) p ( x ∣ z ) p(x|z) p(xz) p ( z ) p(z) p(z)都做了假设,是知道的,所以 p ( x , z ) p(x,z) p(x,z)是知道的。

在这里插入图片描述

参考链接:

[3]:变分推断(Variational Inference)初探 - 知乎 (zhihu.com)

[4]: 变分推断-1贝叶斯统计 详细推导 Variational Inference「机器学习」_哔哩哔哩_bilibili
[5]: 计算 KL距离 (相对熵)

https://zhuanlan.zhihu.com/p/507776434

https://zhuanlan.zhihu.com/p/340226815

https://zhuanlan.zhihu.com/p/57574493

https://blog.csdn.net/m0_58547949/article/details/127963524

https://kexue.fm/archives/5253

  • 24
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
作为一个AI语言模型,我并不具备编写Matlab程序的能力。但是我可以为您提供一些变分推断的Matlab程序示例,供您参考。 1. 高斯混合模型的变分推断 %生成数据 n = 1000; % 数据量 k = 3; % 高斯分量个数 d = 2; % 数据维度 % 生成高斯分量的均值和协方差矩阵 mu = rand(k,d)*10-5; % 均值矩阵 sigma = zeros(d,d,k); % 协方差矩阵 for i=1:k sigma(:,:,i) = diag(rand(d,1)*5); % 对角线矩阵 end % 生成数据 data = zeros(n,d); for i=1:n k_index = randi(k); data(i,:) = mvnrnd(mu(k_index,:),sigma(:,:,k_index)); end % 变分推断 % 初始化参数 alpha = ones(1,k)./k; % 高斯分量的先验概率 mu_init = rand(k,d)*10-5; % 均值矩阵 sigma_init = zeros(d,d,k); % 协方差矩阵 for i=1:k sigma_init(:,:,i) = diag(rand(d,1)*5); % 对角线矩阵 end q_alpha = alpha; % 变分分布的先验概率 q_mu = mu_init; % 均值矩阵 q_sigma = sigma_init; % 协方差矩阵 % 迭代计算 max_iter = 100; for iter=1:max_iter % 更新q_alpha E_z = zeros(n,k); for i=1:n for j=1:k E_z(i,j) = log(alpha(j)) + log_mvnpdf(data(i,:),q_mu(j,:),q_sigma(:,:,j)); end E_z(i,:) = exp(E_z(i,:) - max(E_z(i,:))); % 防止指数爆炸 E_z(i,:) = E_z(i,:) ./ sum(E_z(i,:)); end q_alpha = alpha + sum(E_z,1); % 更新q_mu for j=1:k q_mu(j,:) = sum(E_z(:,j).*data,1) ./ sum(E_z(:,j)); end % 更新q_sigma for j=1:k diff = data - q_mu(j,:); q_sigma(:,:,j) = diff'*diag(E_z(:,j))*diff ./ sum(E_z(:,j)); end end % 计算后验概率 posterior = zeros(n,k); for i=1:n for j=1:k posterior(i,j) = log(q_alpha(j)) + log_mvnpdf(data(i,:),q_mu(j,:),q_sigma(:,:,j)); end posterior(i,:) = exp(posterior(i,:) - max(posterior(i,:))); % 防止指数爆炸 posterior(i,:) = posterior(i,:) ./ sum(posterior(i,:)); end % 显示结果 figure; hold on; scatter(data(:,1),data(:,2),10,posterior(:,1),'filled'); scatter(q_mu(:,1),q_mu(:,2),100,'k','filled'); scatter(mu(:,1),mu(:,2),100,'r','filled'); hold off; title('GMM with Variational Inference'); legend('Cluster 1','Cluster 2','Cluster 3'); xlabel('Feature 1'); ylabel('Feature 2'); 2. 隐马尔可夫模型的变分推断 % 生成数据 n = 1000; % 数据量 k = 3; % 隐状态个数 d = 2; % 数据维度 % 生成隐状态转移矩阵和观测矩阵 A = rand(k,k); % 隐状态转移矩阵 A = A ./ sum(A,2); B = rand(k,d)*10-5; % 观测矩阵 % 生成数据 data = zeros(n,d); z = zeros(n,1); z(1) = randi(k); data(1,:) = mvnrnd(B(z(1),:),eye(d)); for i=2:n z(i) = randsample(k,1,true,A(z(i-1),:)); data(i,:) = mvnrnd(B(z(i),:),eye(d)); end % 变分推断 % 初始化参数 alpha = ones(1,k)./k; % 隐状态的先验概率 A_init = rand(k,k); % 隐状态转移矩阵 A_init = A_init ./ sum(A_init,2); B_init = rand(k,d)*10-5; % 观测矩阵 q_alpha = alpha; % 变分分布的先验概率 q_A = A_init; % 隐状态转移矩阵 q_B = B_init; % 观测矩阵 % 迭代计算 max_iter = 100; for iter=1:max_iter % 更新q_alpha E_z = zeros(n,k); for i=1:n for j=1:k E_z(i,j) = log(alpha(j)) + log(A(z(i-1),j)) + log_mvnpdf(data(i,:),q_B(j,:),eye(d)); end E_z(i,:) = exp(E_z(i,:) - max(E_z(i,:))); % 防止指数爆炸 E_z(i,:) = E_z(i,:) ./ sum(E_z(i,:)); end q_alpha = alpha + E_z(1,:); % 更新q_A for i=1:k for j=1:k q_A(i,j) = sum(E_z(1:end-1,i).*E_z(2:end,j)) ./ sum(E_z(1:end-1,i)); end end % 更新q_B for j=1:k q_B(j,:) = sum(E_z(:,j).*data,1) ./ sum(E_z(:,j)); end end % 计算后验概率 posterior = zeros(n,k); for i=1:n for j=1:k if i==1 posterior(i,j) = log(q_alpha(j)) + log_mvnpdf(data(i,:),q_B(j,:),eye(d)); else posterior(i,j) = log(A(z(i-1),j)) + log_mvnpdf(data(i,:),q_B(j,:),eye(d)); end end posterior(i,:) = exp(posterior(i,:) - max(posterior(i,:))); % 防止指数爆炸 posterior(i,:) = posterior(i,:) ./ sum(posterior(i,:)); end % 显示结果 figure; hold on; scatter(data(:,1),data(:,2),10,posterior(:,1),'filled'); scatter(q_B(:,1),q_B(:,2),100,'k','filled'); scatter(B(:,1),B(:,2),100,'r','filled'); hold off; title('HMM with Variational Inference'); legend('State 1','State 2','State 3'); xlabel('Feature 1'); ylabel('Feature 2');

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值