Implementing Linear Regression with PyTorch

This example fits a linear model y = w*x + b to the training points (1, 2), (2, 4), (3, 6) with torch.nn.Linear, an MSE loss, and plain SGD, then plots the loss over 100 epochs.

import torch
import numpy as np
import matplotlib.pyplot as plt

# prepare the train set
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
cost_list = []
epoch_list = []

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # one input feature -> one output feature (y = w*x + b)
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()

# loss function and optimizer
criterion = torch.nn.MSELoss(reduction='sum')  # size_average=False is deprecated; reduction='sum' is equivalent
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# training loop
for epoch in range(100):
    y_pred = model(x_data)            # forward pass
    loss = criterion(y_pred, y_data)  # compute the loss
    print(epoch, loss.item())

    optimizer.zero_grad()             # clear accumulated gradients
    loss.backward()                   # backward pass
    optimizer.step()                  # update w and b

    cost_list.append(loss.item())
    epoch_list.append(epoch)

# output weight and bias
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

# test model
x_test = torch.Tensor([4.0])
y_test = model(x_test)
print('y_pred=', y_test.data)


# plot cost vs. epoch
plt.plot(epoch_list, cost_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()
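
For inference it is common to disable gradient tracking; the snippet below is a minimal sketch (not part of the original script) that reuses the model trained above, and x_new is a hypothetical test input:

# inference sketch (assumption: run after the training loop above)
model.eval()                          # evaluation mode; no effect here, but a good habit
with torch.no_grad():                 # no autograd graph is needed for prediction
    x_new = torch.Tensor([[5.0]])     # hypothetical test input
    print('prediction for x=5:', model(x_new).item())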

[Figure: cost vs. epoch curve produced by the plotting code above]

Console output:

1 34.9356689453125
2 15.894685745239258
3 7.413261413574219
4 3.6327247619628906
5 1.9449564218521118
6 1.1888999938964844
7 0.8476817607879639
8 0.6912060379981995
9 0.6170368194580078
10 0.5795725584030151
11 0.5585132241249084
12 0.5448195934295654
13 0.5344666242599487
14 0.5256625413894653
15 0.517608106136322
16 0.5099464654922485
17 0.5025186538696289
18 0.49525266885757446
19 0.48811572790145874
20 0.48109155893325806
21 0.4741734266281128
22 0.46735700964927673
23 0.4606401026248932
24 0.45401960611343384
25 0.4474937915802002
26 0.4410627484321594
27 0.4347241222858429
28 0.4284762740135193
29 0.4223185181617737
30 0.4162491261959076
31 0.4102667570114136
32 0.4043706953525543
33 0.3985593318939209
34 0.3928312361240387
35 0.38718563318252563
36 0.38162100315093994
37 0.37613654136657715
38 0.370730996131897
39 0.36540308594703674
40 0.36015161871910095
41 0.3549755811691284
42 0.3498741686344147
43 0.34484604001045227
44 0.3398899734020233
45 0.3350050449371338
46 0.33019065856933594
47 0.32544538378715515
48 0.3207683563232422
49 0.3161582350730896
50 0.31161460280418396
51 0.3071361780166626
52 0.30272218585014343
53 0.29837164282798767
54 0.2940835952758789
55 0.289856880903244
56 0.28569120168685913
57 0.2815854549407959
58 0.2775387465953827
59 0.2735496759414673
60 0.26961857080459595
61 0.2657439112663269
62 0.26192474365234375
63 0.258160263299942
64 0.2544500529766083
65 0.2507934272289276
66 0.24718916416168213
67 0.2436366081237793
68 0.24013523757457733
69 0.2366839349269867
70 0.2332824021577835
71 0.2299298197031021
72 0.22662559151649475
73 0.2233685702085495
74 0.22015826404094696
75 0.2169942557811737
76 0.21387580037117004
77 0.21080203354358673
78 0.2077721357345581
79 0.20478622615337372
80 0.20184344053268433
81 0.19894257187843323
82 0.19608312845230103
83 0.19326522946357727
84 0.1904878318309784
85 0.18775013089179993
86 0.18505176901817322
87 0.18239237368106842
88 0.17977109551429749
89 0.1771874725818634
90 0.17464110255241394
91 0.17213137447834015
92 0.16965728998184204
93 0.1672191172838211
94 0.1648162305355072
95 0.16244745254516602
96 0.1601126492023468
97 0.1578115075826645
98 0.15554358065128326
99 0.1533082127571106
w= 1.739339828491211
b= 0.5925418734550476
y_pred= tensor([7.5499])
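
The training data lie exactly on y = 2x, so the exact fit is w = 2, b = 0; after 100 epochs the parameters (w ≈ 1.74, b ≈ 0.59) are still converging, and running more epochs would bring them closer. As a sanity check, the sketch below (an addition, reusing x_data and y_data from the script above) solves the same problem in closed form with ordinary least squares:

# closed-form least-squares check (sketch)
import numpy as np

X = np.hstack([x_data.numpy(), np.ones((3, 1))])         # design matrix [x, 1]
coef, _, _, _ = np.linalg.lstsq(X, y_data.numpy(), rcond=None)
print('closed-form w =', coef[0, 0], 'b =', coef[1, 0])  # expect w ≈ 2, b ≈ 0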
