用Pytorch实现线性回归
一、概念
上图主要介绍了使用Pytorch解决问题的四个步骤:
1.准备数据集
#1.准备数据集
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])
2.使用Class设计模型
#2.使用Class设计模型
class LinearModel(torch.nn.Module):
#构造函数,初始化
def __init__(self):
super(LinearModel,self).__init__()
self.liner = torch.nn.Linear(1,1) #(1,1)表示输入和输出的维度
#前向传播函数
#forward()相当于对父类_init_()进行重载
def forward(self,x):
y_pred = self.liner(x)
return y_pred
model = LinearModel() #创建类LinearModel的实例
我们的模型类应继承自nn.Module(所有神经网络模块的基类)。成员方法__init __()和forward()必须实现。nn.Linear类包含两个Tensor:权重和偏差。nn.Linear类已实现magic method __call __(),该方法使类的实例可以像函数一样被调用。 通常将调用forward()。
其中关于class torch.nn.Linear的用法:
3.构建损失函数和优化器的选择
#3.构建损失函数和优化器的选择
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
这一部分主要使用pytorch进行损失函数以及优化器的选择
4.进行训练的迭代。
#4.进行训练迭代
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print(epoch,loss.item()) #.item()获取数值大小
#由.backward()计算的grad将被累积。 因此,在反向传播之前,记住将梯度设置为0。
optimizer.zero_grad()
loss.backward()
optimizer.step() #进行更新update
二、完整代码如下:
import torch
#loss 必须是一个标量才可以用backward。
#1.准备数据集
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])
#2.使用Class设计模型
class LinearModel(torch.nn.Module):
#构造函数,初始化
def __init__(self):
super(LinearModel,self).__init__()
self.liner = torch.nn.Linear(1,1) #(1,1)表示输入和输出的维度
#前向传播函数
#forward()相当于对父类_init_()进行重载
def forward(self,x):
y_pred = self.liner(x)
return y_pred
model = LinearModel() #创建类LinearModel的实例
#3.构建损失函数和优化器的选择
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
#4.进行训练迭代
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print(epoch,loss.item()) #.item()获取数值大小
#由.backward()计算的grad将被累积。 因此,在反向传播之前,记住将梯度设置为0。
optimizer.zero_grad()
loss.backward()
optimizer.step() #进行更新update
#输出权重w和偏置b
print('w=',model.liner.weight.item())
print('b=',model.liner.bias.item())
#测试模型
x_test = torch.Tensor([4.0])
y_test = model(x_test)
print('y_pred = ',y_test.data)
运行结果:
参考学习链接:https://www.bilibili.com/video/BV1Y7411d7Ys?t=1699&p=5