pytorch系列(六):各种优化器的性能比较

import torch
import torch.utils.data as Data
import torch.nn.functional as f
import matplotlib.pyplot as plt

#指定超参数
LR=0.01#学习率
BATCH_SIZE=32#批数据的大小
EPOCH=12#迭代次数

#构造数据集
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.1*(torch.normal(torch.zeros(*x.size())))


#打印数据
plt.scatter(x.data.numpy(),y.data.numpy(),c='r')
plt.show()

#使用dataloader工具进行数据的处理
torch_dataset=Data.TensorDataset(x,y)#将x和y转换成torch可识别的数据集
loader=Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

#构造网络结构并为每一个优化器优化一个神经网络
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden=torch.nn.Linear(1,20)
        self.output=torch.nn.Linear(20,1)
    #前向传播
    def forward(self,x):
        x=f.relu(self.hidden(x))
        x=self.output(x)
        return x

#每一个优化器对应一个网络结构
net_SGD=Net()
net_Momentum=Net()
net_RMSprop=Net()
net_Adam=Net()

#放到一个列表中
nets=[net_SGD,net_Momentum,net_RMSprop,net_Adam]


#API化每一个优化器
opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.8)
opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))

#用一个列表存放每一个优化器
optimizers=[opt_SGD,opt_Momentum,opt_RMSprop,opt_Adam]
#指定损失函数
loss_func=torch.nn.MSELoss()
#用一个两层列表记录各个优化器的loss
loss_his=[[],[],[],[]]

#训练  可视化
for epoch in range(EPOCH):
    print(epoch)
    for step,(batch_x,batch_y) in enumerate(loader):
        
        #对于每一个优化器,优化他的神经网络
        for net,opt,l_his in zip(nets,optimizers,loss_his):
            output=net(batch_x)#对每一个网络丢入数据
            loss=loss_func(output,batch_y)#计算预测值和真实值之间的误差
            opt.zero_grad()#梯度清零
            loss.backward()#反向传播
            opt.step()#更新每一个参数
            l_his.append(loss.data.numpy())

#可视化
lables=["SGD","Momentum","RMSprop","Adam"]
for i,l_his in enumerate(loss_his):#enumerate是列举,会迭代列表的中的每一个索引和每一项的值
    plt.plot(l_his,label=lables[i])
plt.legend(loc=1)#legend是做一个图例说明  loc=1表示放在右边  详情看参数,label=lables[i]相对应
plt.xlabel("steps")
plt.ylabel("loss")
plt.ylim((0,0.5))
plt.show()

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值