"""Variational (Bayesian) CNN classifier for scikit-learn's 8x8 digits data.

A small convolutional network is wrapped with ``pyvarinf.Variationalize``,
trained full-batch with Adam on a random train split, and then scored on
the held-out split.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Two conv/pool/batch-norm stages followed by two linear layers.

    Expects input of shape ``(batch, 1, 8, 8)``.  With 2x2 kernels and 2x2
    max-pooling the spatial size shrinks 8 -> 7 -> 3 -> 2 -> 1, so the
    flattened feature vector fed to ``fc1`` has exactly 20 entries.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(2, 2))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(2, 2))
        self.fc1 = nn.Linear(20, 100)
        self.fc2 = nn.Linear(100, 10)
        self.bn1 = nn.BatchNorm2d(10)
        self.bn2 = nn.BatchNorm2d(20)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``."""
        x = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))
        x = self.bn2(F.relu(F.max_pool2d(self.conv2(x), 2)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # dim is given explicitly: the implicit-dim form is deprecated.
        return F.log_softmax(x, dim=1)


def to_image_batch(features):
    """Convert flat 64-pixel rows into a float32 ``(n, 1, 8, 8)`` tensor.

    Args:
        features: array-like of shape ``(n, 64)``.

    Returns:
        A ``torch.float32`` tensor of shape ``(n, 1, 8, 8)`` that does not
        track gradients (inputs are data, not parameters).
    """
    as_float = np.ascontiguousarray(features, dtype=np.float32)
    return torch.from_numpy(as_float).reshape(-1, 1, 8, 8)


def accuracy(predictions, targets):
    """Return the fraction of ``predictions`` equal to ``targets``.

    Args:
        predictions: 1-D integer tensor (or array-like) of predicted labels.
        targets: 1-D integer tensor (or array-like) of true labels.

    Raises:
        ValueError: if the two inputs differ in shape or are empty.
    """
    predictions = torch.as_tensor(predictions)
    targets = torch.as_tensor(targets)
    if predictions.shape != targets.shape:
        raise ValueError(
            'shape mismatch: {} vs {}'.format(
                tuple(predictions.shape), tuple(targets.shape)))
    if predictions.numel() == 0:
        raise ValueError('cannot compute accuracy of an empty prediction set')
    return (predictions == targets).sum().item() / predictions.numel()


def main(steps=500, lr=0.01, test_size=0.3):
    """Train the variational network on the digits data and print accuracy.

    Args:
        steps: number of full-batch optimisation steps.
        lr: Adam learning rate.
        test_size: fraction of the data held out for evaluation.
    """
    # Imported here so that Net and the helpers above can be imported
    # (and tested) without pyvarinf / scikit-learn being installed.
    import pyvarinf
    from sklearn import datasets
    from sklearn.model_selection import train_test_split

    digits = datasets.load_digits()
    print(digits.data.shape)
    print(digits.target.shape)

    x_train_np, x_test_np, y_train_np, y_test_np = train_test_split(
        digits.data, digits.target, test_size=test_size)
    x_train = to_image_batch(x_train_np)
    x_test = to_image_batch(x_test_np)
    y_train = torch.as_tensor(y_train_np, dtype=torch.long)
    y_test = torch.as_tensor(y_test_np, dtype=torch.long)

    model = Net()
    var_model = pyvarinf.Variationalize(model)
    var_model.set_prior('gaussian')
    optimizer = optim.Adam(var_model.parameters(), lr=lr)

    num_train = x_train.size(0)
    var_model.train()
    for step in range(steps):
        optimizer.zero_grad()
        output = var_model(x_train)
        loss_error = F.nll_loss(output, y_train)
        # The prior term is counted once per pass over the data, so it is
        # divided by the number of datapoints actually used to train.
        loss_prior = var_model.prior_loss() / num_train
        loss = loss_error + loss_prior
        loss.backward()
        optimizer.step()
        print('step={}, loss={}'.format(step, loss.detach()))

    # Predict with the wrapper the optimiser was built over, not the bare
    # `model`.  eval() is the usual switch for BatchNorm running statistics
    # (NOTE(review): assumes the wrapper propagates it to the wrapped net).
    var_model.eval()
    with torch.no_grad():
        out = var_model(x_test)
    predictions = out.argmax(dim=1)

    print(predictions.tolist())
    print(y_test)
    print(accuracy(predictions, y_test))


if __name__ == '__main__':
    main()
贝叶斯神经网络对手写体数字的识别预测
最新推荐文章于 2022-09-18 16:55:08 发布