【PyTorch】Linear Regression 用PyTorch实现线性回归

实现步骤为以下四步:

  1. Prepare dataset
  2. Design model using Class:inherit from nn.Module
  3. Construct loss and optimizer:using PyTorch API
  4. Training cycle:forward,backward,update

代码实现:

import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self) -> None:
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()

criterion = torch.nn.MSELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

 运行结果:

955 5.003892624699802e-07
956 4.930132604386017e-07
957 4.860560238739708e-07
958 4.787866600963753e-07
959 4.7221078602888156e-07
960 4.653236942431249e-07
961 4.5884127075623837e-07
962 4.5185811359260697e-07
963 4.457418185666029e-07
964 4.3908954694416025e-07
965 4.3306090447003953e-07
966 4.267702138349705e-07
967 4.2060179339387105e-07
968 4.1470161704637576e-07
969 4.0854621374819544e-07
970 4.0269520695801475e-07
971 3.966299573221477e-07
972 3.911557087121764e-07
973 3.854668761960056e-07
974 3.7989167367413756e-07
975 3.74534749880695e-07
976 3.694633505801903e-07
977 3.6400547287485097e-07
978 3.587275045902061e-07
979 3.5335011716597364e-07
980 3.4842446439142805e-07
981 3.4353348610238754e-07
982 3.3850847103167325e-07
983 3.33721118295216e-07
984 3.290019776613917e-07
985 3.242496973143716e-07
986 3.192694748577196e-07
987 3.1468636052522925e-07
988 3.1026576152726193e-07
989 3.0571646902899374e-07
990 3.0139085538394284e-07
991 2.9731859285675455e-07
992 2.9286502467584796e-07
993 2.8860108614026103e-07
994 2.8446152100514155e-07
995 2.8035196919518057e-07
996 2.7642477107292507e-07
997 2.7234369781581336e-07
998 2.686840616661357e-07
999 2.6469047043065075e-07
w =  1.9996576309204102
b =  0.000778441084548831
y_pred =  tensor([[7.9994]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值