优化器类型比较
优化器分为SGD,Momentum,RMSprop,Adam四类主流分类器
通过批分类加神经网络训练检验四种分类器的效果
1.导入函数库
import torch
import matplotlib.pyplot.as plt
import torch.utils.data as Data
from torch.autograd import Variable
import torch.nn.functional as F
2.生成伪数据
torch.manual_seed(1)
BATCH_SIZE=32
EPOCH=12
LR=0.01
x=torch.unsqueeze(torch.linspace(-1,1,1000),dim=1)
y=x.pow(2)+0.1*torch.normal(torch.zeros(x.size()))
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()
生成数据点图
转化为批处理的数据格式
torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2,
)
搭建网络
第一种搭建方式
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self_hidden=torch.nn.Linear(1,20)
self_predict=torch.nn.Linear(20,1)
def forward(self,x):
x=F.ReLU(self_hidden(x))
x=self_predict(x)
return x
第二种搭建方式
Net=torch.nn.Sequential(
torch.nn.Linear(1,20)
torch.nn.ReLU()
torch.nn.Linear(20,1)
)
分为四种网络
net_SGD=Net()
net_Momentum=Net()
net_RMSprop=Net()
net_Adam=Net()
分为四种优化器
opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.9)
opt_RSMprop=torch.optim.PSMprop(net_RSMprop.parameters(),lr=LR,alpha=0.9)
opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))
optimizers=[opt_SGD,opt_Momentum,opt_RSMprop,opt_Adam]
loss_func=torch.nn.MSELoss()
losses_his=[[],[],[],[]]# 记录 training 时不同神经网络的 loss
批训练并优化网络
if __name__=='__main__':
for epoch in range(EPOCH):
print('epoch: ',epoch)
for step,(batch_,batch_y) in enumerate(loader):
b_x=Variable(batch_x)
b_y=Variable(batch_y)
for net,opt,l_his in zip(nets,optimizers,losses_his):
output=net(b_x)
loss=loss_func(output,b_y)
opt.zerograd()
loss.backward()
opt.step()
l_his.append(loss.data.numpy())
生成图片
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(losses_his):
plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.savefig('comparison.tif')
plt.show()
此处解释什么是enumerate??
enumerate多用于在for循环中得到计数,利用它可以同时获得索引和值
例如
a=[1,2,5,3,9]
for i,value in enumerate(a):
print('i=',i,' value=',value)
结果:
i= 0 value= 1
i= 1 value= 2
i= 2 value= 5
i= 3 value= 3
i= 4 value= 9
生成的图片,如下:
可知Adam优化器误差最小