利用Pyro库进行贝叶斯分析

Pyro:由Uber开发的Python库,用于深度概率建模。Pyro提供了灵活的API,并且与PyTorch紧密集成,后者是另一个流行的深度学习库。Pyro支持多种推理算法,包括变分推理和马尔可夫链蒙特卡洛(MCMC)。

import pyro
import pyro.distributions as dist
import torch
import numpy as np
import matplotlib.pyplot as plt
from pyro.infer import MCMC, NUTS

plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示为方块的

# 定义模型
def model(x, y):
    # 先验:权重和噪声的先验分布
    w_prior = dist.Normal(torch.tensor(0.), torch.tensor(10.))
    b_prior = dist.Normal(torch.tensor(0.), torch.tensor(10.))
    sigma_prior = dist.Uniform(torch.tensor(0.), torch.tensor(10.))

    # 随机样本
    w = pyro.sample("weights", w_prior)
    b = pyro.sample("bias", b_prior)
    sigma = pyro.sample("sigma", sigma_prior)

    # 似然函数
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(w * x + b, sigma), obs=y)

# 示例数据
N = 100  # 数据点数量
x_data = torch.linspace(-8, 8, N)
y_data = torch.tensor(1.5 + x_data + torch.randn(N) * 2., dtype=torch.float)

# 设置MCMC
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=10000, warmup_steps=200)
mcmc.run(x_data, y_data)

# 获取MCMC样本
mcmc_samples = mcmc.get_samples()

# 打印MCMC样本的统计信息
print(mcmc_samples["weights"].mean())
print(mcmc_samples["bias"].mean())
print(mcmc_samples["sigma"].mean())

# 可视化MCMC样本
w_samples = mcmc_samples["weights"].detach().numpy()
b_samples = mcmc_samples["bias"].detach().numpy()
sigma_samples = mcmc_samples["sigma"].detach().numpy()

plt.subplot(3, 1, 1)
plt.hist(w_samples, bins=50, density=True)
plt.title("权重 w")

plt.subplot(3, 1, 2)
plt.hist(b_samples, bins=50, density=True)
plt.title("偏差 b")

plt.subplot(3, 1, 3)
plt.hist(sigma_samples, bins=50, density=True)
plt.title("噪声尺度 sigma")

plt.tight_layout()
plt.show()

# 绘制线性回归线和0.95 HDI
x_line = torch.linspace(x_data.min(), x_data.max(), 100)
w_mean = mcmc_samples["weights"].mean()
b_mean = mcmc_samples["bias"].mean()
y_line = w_mean * x_line + b_mean

# 生成预测值的样本
pred_samples = []
for i in range(10000):
    w = mcmc_samples["weights"][i]
    b = mcmc_samples["bias"][i]
    sigma = mcmc_samples["sigma"][i]
    y_pred = dist.Normal(w * x_line + b, sigma).sample()
    pred_samples.append(y_pred)

pred_samples = torch.stack(pred_samples)

# 计算0.95 HDI
pred_lower = torch.quantile(pred_samples, 0.025, dim=0)
pred_upper = torch.quantile(pred_samples, 0.975, dim=0)

# 绘制数据点和线性回归线
plt.scatter(x_data, y_data, label='数据点')
plt.plot(x_line, y_line, label='回归线')

# 绘制0.95 HDI
plt.fill_between(x_line, pred_lower, pred_upper, color='b', alpha=0.2, label='0.95 HDI')
plt.legend()
plt.show()

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值