import torch
from torch import nn,optim
import random
import numpy as np
import matplotlib.pyplot as plt
#1、准备数据和参数
input_size = 3
output_size = 1
num_epochs = 2000
learning_rate = 0.001
#随机数据
x_train=torch.randn(10,input_size) #可以调整维度
y_train=2*torch.randn(10,output_size)
#2、构造模型,用现成的
model=nn.Linear(input_size,output_size) #线性模型
criterion=nn.MSELoss() #损失函数
optimizer=optim.Adam(model.parameters(),lr=learning_rate) #优化函数
#自制迭代器
def myenumerate(s_num,e_num):
while s_num<e_num:
yield s_num
s_num+=1
#3、训练模型
#plt.ion() #动态图开始
for epoch in myenumerate(0,num_epochs):
predit=model(x_train) #预测数据
loss=criterion(y_train,predit) #计算损失
optimizer.zero_grad() #梯度置零
loss.backward() #反向求导数
optimizer.step() #模型更新
if (epoch+1)%50==0:
plt.cla() #清除图像
plt.plot(np.arange(y_train.shape[0]),y_train,'ro',label='original labels')
plt.plot(np.arange(y_train.shape[0]),predit.detach().numpy(),'go',label='predit labels')
plt.text(y_train.shape[0]/2,np.min(y_train.detach().numpy()+0.1),'epoch:{}/{},loss:{}'.format(epoch,num_epochs,loss.item().__round__(2)))
plt.legend()
plt.pause(0.2)
#plt.ioff() #动态图结束
plt.show()
我学Pytorch之二~~~Pytorch最简单的线性模型,绝对适合入门
最新推荐文章于 2022-10-24 21:55:43 发布