用PyTorch实现多层网络(给代码截图参考)
- 引入模块,读取数据
- 构建计算图(构建网络模型)
- 损失函数与优化器
- 开始训练模型
- 对训练的模型预测结果进行评估
引入模块
import torch
import numpy as np
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
%matplotlib inline
数据转化为tensor
x_train_tensor=torch.from_numpy(x_train)
x_test_tensor=torch.from_numpy(x_test)
y_train_numpy=np.array(y_train)
y_train_tensor=torch.from_numpy(y_train_numpy)
y_test_numpy=np.array(y_test)
y_test_tensor=torch.from_numpy(y_test_numpy)
x=x_train_tensor.float()
y=y_train_tensor.float()
构建网络模型
构建一个六层的网络
class module_net(nn.Module):
def __init__(self, num_input, num_hidden, num_output):
super(module_net, self).__init__()
self.layer1 = nn.Linear(num_input, num_hidden)
self.layer2 = nn.Tanh()
self.layer3 = nn.Linear(num_hidden, num_hidden)
self.layer4 = nn.Tanh()
self.layer5 = nn.Linear(num_hidden, num_hidden)
self.layer6 = nn.Tanh()
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = self.layer6(x)
return x
损失函数与优化器
criterion=nn.BCEWithLogitsLoss()
mo_net = module_net(8, 10, 1)
optim = torch.optim.SGD(mo_net.parameters(), 0.01, momentum=0.9)
模型训练
Loss_list = [] #用来装loss值,以便之后画图
Accuracy_list = [] #用来装准确率,以便之后画图
for e in range(10000):
out = mo_net.forward(Variable(x)) #这里省略了 mo_net.forward()
loss = criterion(out, Variable(y))
Loss_list.append(loss.data[0])
#--------------------用于求准确率-------------------------#
out_class=(out[:]>0).float() #将out矩阵中大于0的转化为1,小于0的转化为0,存入a中
right_num=torch.sum(y==out_class).float() #分类对的数值
precision=right_num/out.shape[0] #准确率
#--------------------求准确率结束-------------------------#
Accuracy_list.append(precision)
optim.zero_grad()
loss.backward()
optim.step()
if (e + 1) % 1000 == 0:
print('epoch: {}, loss: {},precision{},right_num{}'.format(e+1, loss.data[0],precision,right_num))
plt.plot(x1, Loss_list,c='red',label='loss')
plt.plot(x1, Accuracy_list,c='blue',label='precision')
plt.legend()
结果: