PyTorch深度学习实践第五讲——Pytorch实现线性回归

步骤:

1.Prepare dataset(准备数据集),仍然采用前几节构造的简单数据集,但注意这里x和y都是一列的数据。

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

 

2.Design model using Class(设计模型,inherit from nn.Module,计算y_pred),创建线性模型类,从nn.Module类继承,必须要实现构造函数并重写forward()方法。

其中,构造函数,用到torch.nn.Linear类创建linear对象(linear对象包含w和b),需要指定输入输出的维度。

CLASS torch.nn.Linear(in_features, out_features, bias=True)

in_features – size of each input sample

out_features – size of each output sample

bias – If set to False, the layer will not learn an additive bias. Default: True

class LinearModel(torch.nn.Module):
    def __init__(self): #构造函数
        super(LinearModel, self).__init__() #使用父类的构造函数
        self.linear = torch.nn.Linear(1,1) #构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b
    def forward(self, x):
        y_pred = self.linear(x)#可调用对象,计算y=wx+b
        return y_pred

forward()函数中要注意,linear是构造函数中创建的对象,linear后直接加参数,说明linear是一个类似于函数的可调用对象,可调用对象要求必须实现__call__()方法,__call__()需要调用forward(),所以必须要重写forward()函数。

3.Construct Loss and Optimizer(构建损失函数和优化器,using Pytorch API)

model = LinearModel() #实例化模型
criterion = torch.nn.MSELoss(reduction='sum') #损失函数
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

4.Traing cycle(实现训练周期,即forward, backward, update)

for epoch in range(100):
    y_pred = model(x_data)  #第一步,计算y_pred
    loss = criterion(y_pred, y_data) #第二步,计算损失
    print(epoch, loss.item()) 
    
    optimizer.zero_grad() 
    loss.backward() #第三步,backward
    optimizer.step() #第四步,更新参数

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

使用SGD、Adagrad、Adam、Adamax、ASGD、RMSprop、Rprop 同优化器的loss图。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值