Pytorch_optimizer_comparison

优化器类型比较
优化器分为SGD,Momentum,RMSprop,Adam四类主流分类器
通过批分类加神经网络训练检验四种分类器的效果
1.导入函数库

import torch
import matplotlib.pyplot.as plt
import torch.utils.data as Data
from torch.autograd import Variable
import torch.nn.functional as F

2.生成伪数据

torch.manual_seed(1)
BATCH_SIZE=32
EPOCH=12
LR=0.01
x=torch.unsqueeze(torch.linspace(-1,1,1000),dim=1)
y=x.pow(2)+0.1*torch.normal(torch.zeros(x.size()))
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

生成数据点图在这里插入图片描述
转化为批处理的数据格式

torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(
	dataset=torch_dataset,
	batch_size=BATCH_SIZE,
	shuffle=True,
	num_workers=2,
)

搭建网络
第一种搭建方式

class Net(torch.nn.Module):
	def __init__(self):
		super(Net,self).__init__()
		self_hidden=torch.nn.Linear(1,20)
		self_predict=torch.nn.Linear(20,1)
	def forward(self,x):
		x=F.ReLU(self_hidden(x))
		x=self_predict(x)
		return x		

第二种搭建方式

Net=torch.nn.Sequential(	
	torch.nn.Linear(1,20)
	torch.nn.ReLU()
	torch.nn.Linear(20,1)
	)

分为四种网络

net_SGD=Net()
net_Momentum=Net()
net_RMSprop=Net()
net_Adam=Net()

分为四种优化器

opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.9)
opt_RSMprop=torch.optim.PSMprop(net_RSMprop.parameters(),lr=LR,alpha=0.9)
opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))
optimizers=[opt_SGD,opt_Momentum,opt_RSMprop,opt_Adam]
loss_func=torch.nn.MSELoss()
losses_his=[[],[],[],[]]# 记录 training 时不同神经网络的 loss

批训练并优化网络

if __name__=='__main__':
	for epoch in range(EPOCH):
		print('epoch: ',epoch)
		for step,(batch_,batch_y) in enumerate(loader):
		b_x=Variable(batch_x)
		b_y=Variable(batch_y)
		for net,opt,l_his in zip(nets,optimizers,losses_his):
			output=net(b_x)
			loss=loss_func(output,b_y)
			opt.zerograd()
			loss.backward()
			opt.step()
			l_his.append(loss.data.numpy())

生成图片

labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(losses_his):
    plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.savefig('comparison.tif')
plt.show()

此处解释什么是enumerate??
enumerate多用于在for循环中得到计数,利用它可以同时获得索引和值
例如

a=[1,2,5,3,9]
for i,value in enumerate(a):
    print('i=',i,'  value=',value)
 
结果:
i= 0   value= 1
i= 1   value= 2
i= 2   value= 5
i= 3   value= 3
i= 4   value= 9

生成的图片,如下:
可知Adam优化器误差最小
可知Adam优化器误差最小

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值