# 实现 重要性采样
import matplotlib.pyplot as plt
# import random
import numpy as np
# from scipy.stats import norm, poisson, binom, uniform
from scipy.special import logsumexp
class Pdf:
def __call__(self, x):
# 返回密度函数log值
pass
def sample(self, n):
# 在对应分布中抽取n个样本
pass
class Norm(Pdf):
def __init__(self, mu=0, sigma=1):
self.mu = mu
self.sigma = sigma
def __call__(self, x):
transferred_x = ((x - self.mu) / self.sigma) ** 2 * (-0.5)
return transferred_x
def sample(self, n):
sample_list = np.random.normal(self.mu, self.sigma, n)
return sample_list
class Uniform(Pdf):
def __init__(self, low, high):
self.low = low
self.high = high
def __call__(self, x):
# U(low,high)的概率密度函数是个常数,不依赖于x
# 也即是所有的x,其pdf都是-np.log(self.high-self.low,故需要repeat n次数
transferred_x = np.repeat(-np.log(self.high - self.low), len(x))
# 如果q不是均匀分布,也即q的pdf依赖于x的取值,则只用transferred_x = np.log(q(x)),不影响主函数中的 p = np.exp(...,.., p = exp(logws))
return transferred_x
def sample(self, n):
sample_list = np.random.uniform(self.low, self.high, n)
return sample_list
class ImportantSampler:
def __init__(self, p_dist, q_dist):
self.p_dist = p_dist
self.q_dist = q_dist
def sample(self, n):
samples = self.q_dist.sample(n) # q_dist和p_dist被视作类内定义的实例
weights = self.calc_weights(samples)
# 将wi(加权数)标准化(softmax方法)
norm_weights = weights - logsumexp(weights) #此处logsumexp函数,防止数值下溢
return samples, norm_weights
def calc_weights(self, samples):
# log(wi) = log(p(X)/q(x)) = log(p(x))-log(q(x))
logws = self.p_dist(samples) - self.q_dist(samples)
return logws
if __name__ == '__main__':
N = 200000
target_p = Norm() # 默认mu=0,sigma=1
imq_q = Uniform(-10, 30)
sampler = ImportantSampler(target_p, imq_q) # sampler是创建的实例,用于重要性采样
# biased_sampler是通过重要性采样从q_dist中获取的样本
biased_sampler, logws = sampler.sample(N)
# p是每个样本被选中的概率,(p(x)/q(x))*q(x),不考虑常数项
# 下面这行独立于p与q的pdf,p是array,对于每一个x,都有对应的w(x),x本身服从q(x),因此这行的结果是q(x)*w(x)
samples = np.random.choice(biased_sampler, N, p=np.exp(logws)) # 默认有放回取值,即replace= True
# bins = 20,设置直方块的个数为20
plt.hist(samples, bins=20)
plt.show()
重要性采样
最新推荐文章于 2024-07-28 15:46:11 发布
![](https://img-home.csdnimg.cn/images/20240711042549.png)