马尔可夫链有两个要素:
- 一步转移概率矩阵:
- 初始分布:
如果这两个要素都确定了,这个链的转移行为就被完全确定下来了。我们就可以求得极限分布 ,只需解下面这个方程即可。
但是MCMC试图解决的问题刚好是反过来。即已知极限分布 ,如何求得一步转移概率矩阵 。
如果这件事情能够做到,而我们的目标是想要获得服从分布 的伪随机数。那么我们只需要找到一个符合条件的矩阵 ,然后运行这条马氏链直到达到稳态,从这一时刻之后起,生成的每一个随机数都可以当做我们的采样样本。
下面,我们将分三步来说明解算这个问题的方法:
Step1. 什么是细致平衡方程(Detailed Balance)
如果满足如下方程:
我们管这样的关系叫做细致平衡:当你取得平衡之后,各个状态之间的相互的转移行为是处于一种平衡态上的,即转移出去的量和转移进来的量是相等的。也可以看成是在做某种平衡的物质交换。
满足细致平衡方程的马氏链一定处于稳态:。注意,这是一个充分条件,证明如下:
Step2. 任取一个一步转移概率矩阵作为一个 Proposal
Step3. 对 Proposal 的矩阵进行如下的校正(Correction)
不难发现满足细致平衡方程:,于是 也就是我们要找的那个 。
这个方法的名字叫做 Metropolis-Hastings。
其实只要基于细致平衡方程,我们不难从物质交换的角度理解 Step3 的校正所干的事情:
如果 所转移物质,超过了 所转移的物质,即 ,那么 Metropolis-Hastings 的策略就是:乘以这一项 使之达到和 一样的物质交换水平,即:
但是我学的时候一直存在一个疑问,就是 的值校正之后可能会减小,那么 是否还满足一步转移概率矩阵行和为1的条件呢?据说只要将减小的损失部分补到对角线上就可以了。但是这一点我始终搞不太明白为什么。因为教材普遍都是用接受拒绝的实现策略一带而过。
对于连续版本的Proposal矩阵,其实就是选择一个分布,我们可以选择高斯分布,如此一来,转移概率就等价于:
这里摘录西瓜书中的算法实现细节:
- 代码实现:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import norm, gamma
# 目标分布为标准正态分布
target_dist = norm(10, 1)
def M_H(target_dist, proposal_std, num_samples = 10000, warm_up_samples = 5000):
'''
warm_up_samples = 5000 # warm-up阶段的样本数量
'''
# 初始化采样链
initial_state = 0
samples = [initial_state]
cnt = 0
# 开始M-H采样
for i in range(num_samples):
current_state = samples[-1]
# 从建议分布中抽样新样本
proposed_state = np.random.normal(current_state, proposal_std)
# 计算接受概率
acceptance_prob = min(1, target_dist.pdf(proposed_state) / target_dist.pdf(current_state))
# 接受/拒绝新样本
if np.random.rand() < acceptance_prob:
samples.append(proposed_state)
else:
samples.append(current_state)
cnt += 1
return samples[warm_up_samples:], cnt
# 设置M-H算法的参数
proposal_std = 1.0
num_samples = 40000
warm_up_samples = 20000
samples, cnt = M_H(target_dist, proposal_std, num_samples, warm_up_samples)
# 输出采样结果
print("M-H采样结果:")
print(samples)
print(f"丢失率 = {cnt / num_samples}")
sns.histplot(samples, bins=50, kde=True, color='skyblue', edgecolor='black', linewidth=0.5)
你可能会有疑问,为什么acceptance_prob会是直接两个目标分布的概率相除呢?因为对于对称分布Normal来说,从转移到和从转移到的概率是一样的,所以分子分母消去了。
- 采样结果:
丢失率 = 0.293475