【PyTorch深度学习】 第四讲:用PyTorch实现线性回归

1.原理

使用PyTorch的内部函数实现线性回归,训练模型

2.代码展示

# author:ZhuYuYing
# data:2021/7/6
# projectName:tor-start
import torch


x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])


'''
__init__()
    初始化函数
torch.nn.Linear
    第一个参数:weight
    第二个参数:bias
'''

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = LinearModel()


'''
torch.nn.MSELoss
    (1) 如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss
    (2) 如果 reduce = True,那么 loss 返回的是标量
    (3) reduction = ‘none’,直接返回向量形式的 loss
    (4) reduction = ‘sum’,返回loss之和
    (5) reduction = ''elementwise_mean,返回loss的平均值
    (6) reduction = ''mean,返回loss的平均值

torch.optim.SGD
    优化函数,model.parameters()为该实例中可优化的参数
    lr为参数优化的选项(学习率等)
'''

criterion = torch.nn.MSELoss(reduction='sum') #计算MSE
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 梯度下降


for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    optimizer.zero_grad()  # 将梯度初始化为零
    loss.backward()
    optimizer.step()  # update 参数,即更新w和b的值

#训练结果:w,b最佳值
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())


#测试一个数据
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_test = ', y_test.data)
结果打印
0 34.59614181518555
1 15.420919418334961
2 6.884373188018799
3 3.0838661193847656
4 1.3917133808135986
···············
···············
···············
997 1.991135434309399e-08
998 1.9568119569157716e-08
999 1.9304664533592586e-08

w =  2.0000925064086914
b =  -0.0002103645383613184

y_test =  tensor([[8.0002]])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值