此次培训内容为搭建简单模型。
一、模拟散点数据
1、创建由100个范围为-1~1的数构成的一维张量用于表示x的值,再将x转换为100*1的矩阵;
2、根据公式y=x^4创建y的矩阵值,为了保证y的严谨性和合理性需要增加干扰,由于干扰由torch.rand()产生,所以每次产生的干扰数都不同,输出的值也就不一样;
3、将x和y都转换为numpy数据类型,即将tensor([[......]])转换为[[......]];
4、绘制散点图并可视化;
5、相关值的输出如图7~8所示。
图1 print data
图2 print x 图3 print y(y不加干扰时) 图4 print y(y加干扰时)
图5 print x_np 图6 print y_np
图7 散点图(y不加干扰时)
图8 散点图(y加干扰时)
二、设置模型和超参数
设置简单模型即定义简单神经网络,需要继承nn.Module类。再定义构造方法,构造方法需要给出全连接层、激活函数的设置。接着重写前向传播方法,过程就是用前面定义的各种操作来搭积木。所谓的前向传播算法就是:将上一层的输出作为下一层的输入,并计算下一层的输出,一直到运算到输出层为止。接着设置学习率为0.2、优化器、损失函数。optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用。
推荐一篇介绍反向传播的知乎:
神经网络反向传播参数更新详解 - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/52473487
三、进行训练
for循环300次(0~299):调用模型预测y的值-->计算loss-->参数更新-->可视化-->下一次循环。训练结果如下:
图9 训练结果可视化
四、完整代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#模拟散点数据
#torch.linspace()第一二个参数为start和end,第三个参数为需要生成的点的个数
data = torch.linspace(-1,1,100) #-1~1的100个点构成的一维张量
#torch.unsqueeze()的dim取值范围是[-2,1],-2和0效果一样:矩阵向量;-1和1效果一样:非向量矩阵,如在此代码为100*1矩阵
x = torch.unsqueeze(data, dim=1) #二维张量/扩充维度
#y = x.pow(4) #不增加干扰
y = x.pow(4)+0.2*torch.rand(x.size()) #增加干扰
x_np = x.data.numpy() #tensor数据类型转换为numpy
y_np = y.data.numpy()
plt.scatter(x_np,y_np) #绘制散点图
plt.show()
#设置模型
class Net(nn.Module): #继承nn.Module类
#定义构造方法
def __init__(self):
super(Net,self).__init__() #继承父类/初始化
#nn.Linear(1,10)表示将原来的100*1矩阵转换为100*10矩阵并将数据内容也作处理
self.Linear1=nn.Linear(1,10) #定义全连接层1
self.Relu=nn.ReLU() #定义激活函数
self.Linear2=nn.Linear(10,1) #定义全连接层2
#重写前向传播方法
def forward(self, x):
x=self.Relu(self.Linear1(x))
y_pre=self.Linear2(x)
return(y_pre)
model=Net() #模型加载,本质就是创建Net类的实例对象
#设置超参数
lr = 0.2 #学习率
optimizer = torch.optim.SGD(model.parameters(),lr=lr) #优化器
loss_function = nn.MSELoss() #损失函数
#进行训练
for t in range(300):
y_predict = model(x)
loss=loss_function(y_predict,y)
print(t,loss.item()) #打印训练次数和loss值,loss值越小,效果越好
plt.ion() #开启交互,为了连续显示动画变化效果,否则只能一帧一帧的显示
plt.show()
#没有下面这三行语句,loss值不会发生改变,训练无用
optimizer.zero_grad() #将梯度归零
loss.backward() #反向传播计算参数梯度
optimizer.step() #梯度下降执行参数更新
plt.cla()
plt.scatter(x_np,y_np) #绘制散点图
plt.plot(x_np,y_predict.data.numpy(),'r-',lw=5)#绘制曲线
plt.pause(0.1)