使用Pytorch和Pyro实现贝叶斯神经网络(Bayesian Neural Network)

本文翻译自:https://www.noconote.work/entry/2019/01/01/160100

最近概率模型和神经网络相结合的研究变得多了起来,这次使用Uber开源的Pyro来实现一个贝叶斯神经网络。概率编程框架最近出了不少,Uber的Pyro基于Pytorch,Google的Edward基于TensorFlow,还有一些独立的像PyMC3,Stan,Pomegranate等等。
这次选择使用Pyro实现的原因是,Pyro基于Numpy实现,加上对Pytorch的支持很友好,可以自动计算梯度,动态计算,这些好处当然都是Pytorch带来的。官网链接:https://pyro.ai/, 介绍到此为止,接下来开始一步一步写代码了。

Step1:导入需要的模块

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms

import pyro
from pyro.distributions import Normal
from pyro.distributions import Categorical
from pyro.optim import Adam
from pyro.infer import SVI
from pyro.infer import Trace_ELBO

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

Step2:准备数据集,这次使用的是MNIST

train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("/tmp/mnist", train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),])),
        batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("/tmp/mnist", train=False, transform=transforms.Compose([transforms.ToTensor(),])

Step3:定义神经网络模型

class BNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.out(output)
        return output

net = BNN(28*28, 1024, 10)

Step4:基于正态分布,初始化weights和bias

def model(x_data, y_data):
    # define prior destributions
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight),
                                      scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias),
                                      scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight),
                                       scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias),
                                       scale=torch.ones_like(net.out.bias))
    
    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias': fc1b_prior, 
        'out.weight': outw_prior,
        'out.bias': outb_prior}
    
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()
    
    lhat = F.log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

Step5:为了近似后验概率分布,先定义一个函数

def guide(x_data, y_data):
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = F.softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = F.softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = F.softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)

    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = F.softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    
    priors = {
        'fc1.weight': fc1w_prior,
        'fc1.bias': fc1b_prior,
        'out.weight': outw_prior,
        'out.bias': outb_prior}
    
    lifted_module = pyro.random_module("module", net, priors)
    
    return lifted_module()

Step6:定义Optimizer
*这里多说两句,优化器还是喜闻乐见的Adam;这个SVI函数是干嘛的呢,它使用变分推理(Variational inference)来逼近后验概率分布,说到变分推理自然离不开ELBO,这里的Trace_ELBO函数是继承的父类ELBO,还有其他种类的ELBO,比如JitTrace_ELBO,TraceGraph_ELBO等等,具体选哪个,去看文档就好了。顺便推荐一个变分推理的教程,徐亦达老师的视频教程非常棒。

optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

Step7:训练Bayesian Neural Network

n_iterations = 5
loss = 0

for j in range(n_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        loss += svi.step(data[0].view(-1,28*28), data[1])

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    
    print("Epoch ", j, " Loss ", total_epoch_loss_train)

Step8:使用训练好的模型,预测结果
*Bayesian Neural Network要分别针对每个类别进行采样,然后计算预测值中概率最大的类别,作为结果。这里有10个数字,所以我们设置采样的次数为10

n_samples = 10

def predict(x):
    sampled_models = [guide(None, None) for _ in range(n_samples)]
    yhats = [model(x).data for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1,28*28))
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))

*当然也可以把计算出的概率归一化一下,这样使得结果更具有解释性

def predict_prob(x):
    sampled_models = [guide(None, None) for _ in range(n_samples)]
    yhats = [model(x).data for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return mean
#正则化
def normalize(x):
    return (x - x.min()) / (x.max() - x.min())
#可视化
def plot(x, yhats):
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(8, 4))
    axL.bar(x=[i for i in range(10)], height= F.softmax(torch.Tensor(normalize(yhats.numpy()))[0]))
    axL.set_xticks([i for i in range(10)], [i for i in range(10)])
    axR.imshow(x.numpy()[0])
    plt.show()

效果如下:

x, y = test_loader.dataset[0]
yhats = predict_prob(x.view(-1, 28*28))
print("ground truth: ", y.item())
print("predicted: ", yhats.numpy())
plot(x, yhats)

ground truth:  7
predicted:  [[  39.87391   -86.72453   112.25181   102.81404   -89.95473    52.947186
  -167.88455   402.66757   -29.799454   95.75331 ]]

在这里插入图片描述
结束了。

评论 43
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值