03 pytorch01 最简单的回归_哔哩哔哩_bilibili
注意类型
提高精确度
1、增加训练次数
2、加深网络层数
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
def show():
print(x)
print(y)
plt.scatter(x.numpy(),y.numpy())
plt.show()
if __name__ == '__main__':
'''
目标:预测y=2*x+1
训练一个模型
给定一个x,输出一个y
'''
# 1 得到数据
x = torch.linspace(1,10,200)
# bug 1*200 200*1
x = x.reshape(200,1)
y = 2*x+1
show()
# 2 定义模型 不推荐复杂的要定义类,这里比较简单
myModel = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
print(myModel)
# 定义损失函数
loss_function = torch.nn.MSELoss()
# 定义梯度下降算法 新手不建议调参数
optimizer=torch.optim.Adam(myModel.parameters())
# 训练一个模型
num_epochs = 5000
# 模型训练过程
for epoch in range(num_epochs):
optimizer.zero_grad()
forward = myModel(x)
loss_value = loss_function(forward,y)
loss_value.backward()
optimizer.step()
print(f'epoch:{epoch+1},loss:{loss_value:f}')
# 模型预测
myModel.eval()
with torch.no_grad():
output = myModel(torch.tensor([[1],[2]],dtype= torch.float))
print(output)