【PyTorch深度学习五】线性回归

# Prepare dataset
import torch
import matplotlib.pyplot as plt
# Three training samples stored as (3, 1) column tensors; the targets are
# exactly twice the inputs.
x_data = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float32)
y_data = torch.tensor([[2.0], [4.0], [6.0]], dtype=torch.float32)

# Design a model
class LinearModel(torch.nn.Module):
    """Linear regression model consisting of a single ``torch.nn.Linear`` layer.

    Args:
        in_features: Size of each input sample. Defaults to 1.
        out_features: Size of each output sample. Defaults to 1.

    Calling ``LinearModel()`` with no arguments builds the same 1-in / 1-out
    layer as before; the parameters only let the class be reused for other
    layer sizes.
    """

    def __init__(self, in_features: int = 1, out_features: int = 1) -> None:
        super().__init__()
        # Exposed as ``self.linear`` so callers can read ``linear.weight`` and
        # ``linear.bias`` after training.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the prediction."""
        return self.linear(x)

model = LinearModel()

# Construct Loss and Optimizer
# reduction='sum' adds the squared errors of all samples instead of averaging.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

# Per-epoch history, collected only for the plot at the end of the script.
Loss_list = []
epoch_list = []

# Training Cycle
for epoch in range(100):
    # Forward pass over the whole data set at once.
    y_pred = model(x_data)
    Loss = criterion(y_pred, y_data)

    # .item() stores a plain Python float, so the history does not hold on to
    # the tensor and its autograd graph.
    Loss_list.append(Loss.item())
    epoch_list.append(epoch)

    print(epoch, Loss)
    # Clear the gradients of the previous step before computing new ones.
    optimizer.zero_grad()
    Loss.backward()
    optimizer.step()


print('w= ', model.linear.weight.item())
print('b= ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
# Inference only: no_grad() skips graph construction, so the result is already
# detached and the legacy `.data` access is not needed.
with torch.no_grad():
    y_test = model(x_test)

print('y_pred= ', y_test)

plt.plot(epoch_list, Loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Cost in each epoch')
plt.show()

结果:

0 tensor(28.7887, grad_fn=<MseLossBackward0>)
1 tensor(28.2193, grad_fn=<MseLossBackward0>)
2 tensor(27.5438, grad_fn=<MseLossBackward0>)
3 tensor(26.7441, grad_fn=<MseLossBackward0>)
4 tensor(25.8004, grad_fn=<MseLossBackward0>)
5 tensor(24.6909, grad_fn=<MseLossBackward0>)
6 tensor(23.3923, grad_fn=<MseLossBackward0>)
7 tensor(21.8814, grad_fn=<MseLossBackward0>)
8 tensor(20.1367, grad_fn=<MseLossBackward0>)
9 tensor(18.1412, grad_fn=<MseLossBackward0>)
10 tensor(15.8882, grad_fn=<MseLossBackward0>)
11 tensor(13.3885, grad_fn=<MseLossBackward0>)
12 tensor(10.6823, grad_fn=<MseLossBackward0>)
13 tensor(7.8575, grad_fn=<MseLossBackward0>)
14 tensor(5.0764, grad_fn=<MseLossBackward0>)
15 tensor(2.6156, grad_fn=<MseLossBackward0>)
16 tensor(0.9246, grad_fn=<MseLossBackward0>)
17 tensor(0.7127, grad_fn=<MseLossBackward0>)
18 tensor(0.7127, grad_fn=<MseLossBackward0>)
19 tensor(0.5708, grad_fn=<MseLossBackward0>)
20 tensor(0.5717, grad_fn=<MseLossBackward0>)
21 tensor(0.4721, grad_fn=<MseLossBackward0>)
22 tensor(0.3920, grad_fn=<MseLossBackward0>)
23 tensor(0.3365, grad_fn=<MseLossBackward0>)
24 tensor(0.2800, grad_fn=<MseLossBackward0>)
25 tensor(0.2230, grad_fn=<MseLossBackward0>)
26 tensor(0.1672, grad_fn=<MseLossBackward0>)
27 tensor(0.1184, grad_fn=<MseLossBackward0>)
28 tensor(0.0894, grad_fn=<MseLossBackward0>)
29 tensor(0.0680, grad_fn=<MseLossBackward0>)
30 tensor(0.0479, grad_fn=<MseLossBackward0>)
31 tensor(0.0399, grad_fn=<MseLossBackward0>)
32 tensor(0.0299, grad_fn=<MseLossBackward0>)
33 tensor(0.0212, grad_fn=<MseLossBackward0>)
34 tensor(0.0140, grad_fn=<MseLossBackward0>)
35 tensor(0.0104, grad_fn=<MseLossBackward0>)
36 tensor(0.0066, grad_fn=<MseLossBackward0>)
37 tensor(0.0041, grad_fn=<MseLossBackward0>)
38 tensor(0.0036, grad_fn=<MseLossBackward0>)
39 tensor(0.0019, grad_fn=<MseLossBackward0>)
40 tensor(0.0010, grad_fn=<MseLossBackward0>)
41 tensor(0.0005, grad_fn=<MseLossBackward0>)
42 tensor(0.0006, grad_fn=<MseLossBackward0>)
43 tensor(1.8464e-05, grad_fn=<MseLossBackward0>)
44 tensor(0.0020, grad_fn=<MseLossBackward0>)
45 tensor(0.0020, grad_fn=<MseLossBackward0>)
46 tensor(0.0005, grad_fn=<MseLossBackward0>)
47 tensor(5.1644e-05, grad_fn=<MseLossBackward0>)
48 tensor(5.1644e-05, grad_fn=<MseLossBackward0>)
49 tensor(7.8988e-05, grad_fn=<MseLossBackward0>)
50 tensor(7.8988e-05, grad_fn=<MseLossBackward0>)
51 tensor(1.7733e-05, grad_fn=<MseLossBackward0>)
52 tensor(6.9844e-05, grad_fn=<MseLossBackward0>)
53 tensor(6.9844e-05, grad_fn=<MseLossBackward0>)
54 tensor(2.6660e-05, grad_fn=<MseLossBackward0>)
55 tensor(2.0058e-05, grad_fn=<MseLossBackward0>)
56 tensor(2.0058e-05, grad_fn=<MseLossBackward0>)
57 tensor(1.7194e-05, grad_fn=<MseLossBackward0>)
58 tensor(1.8231e-05, grad_fn=<MseLossBackward0>)
59 tensor(1.5087e-05, grad_fn=<MseLossBackward0>)
60 tensor(1.3668e-05, grad_fn=<MseLossBackward0>)
61 tensor(1.4037e-05, grad_fn=<MseLossBackward0>)
62 tensor(1.2818e-05, grad_fn=<MseLossBackward0>)
63 tensor(1.1866e-05, grad_fn=<MseLossBackward0>)
64 tensor(1.1301e-05, grad_fn=<MseLossBackward0>)
65 tensor(1.0518e-05, grad_fn=<MseLossBackward0>)
66 tensor(9.8766e-06, grad_fn=<MseLossBackward0>)
67 tensor(8.8999e-06, grad_fn=<MseLossBackward0>)
68 tensor(7.9224e-06, grad_fn=<MseLossBackward0>)
69 tensor(6.8427e-06, grad_fn=<MseLossBackward0>)
70 tensor(5.6809e-06, grad_fn=<MseLossBackward0>)
71 tensor(4.4789e-06, grad_fn=<MseLossBackward0>)
72 tensor(3.3119e-06, grad_fn=<MseLossBackward0>)
73 tensor(2.3130e-06, grad_fn=<MseLossBackward0>)
74 tensor(1.7230e-06, grad_fn=<MseLossBackward0>)
75 tensor(1.2697e-06, grad_fn=<MseLossBackward0>)
76 tensor(8.7076e-07, grad_fn=<MseLossBackward0>)
77 tensor(7.2190e-07, grad_fn=<MseLossBackward0>)
78 tensor(5.1235e-07, grad_fn=<MseLossBackward0>)
79 tensor(3.4525e-07, grad_fn=<MseLossBackward0>)
80 tensor(2.1836e-07, grad_fn=<MseLossBackward0>)
81 tensor(1.5809e-07, grad_fn=<MseLossBackward0>)
82 tensor(8.2677e-08, grad_fn=<MseLossBackward0>)
83 tensor(4.4136e-08, grad_fn=<MseLossBackward0>)
84 tensor(4.6352e-08, grad_fn=<MseLossBackward0>)
85 tensor(1.2817e-08, grad_fn=<MseLossBackward0>)
86 tensor(8.7319e-08, grad_fn=<MseLossBackward0>)
87 tensor(8.7319e-08, grad_fn=<MseLossBackward0>)
88 tensor(2.7766e-08, grad_fn=<MseLossBackward0>)
89 tensor(1.5190e-08, grad_fn=<MseLossBackward0>)
90 tensor(1.5190e-08, grad_fn=<MseLossBackward0>)
91 tensor(1.3445e-08, grad_fn=<MseLossBackward0>)
92 tensor(1.3445e-08, grad_fn=<MseLossBackward0>)
93 tensor(1.2285e-08, grad_fn=<MseLossBackward0>)
94 tensor(1.2185e-08, grad_fn=<MseLossBackward0>)
95 tensor(1.1031e-08, grad_fn=<MseLossBackward0>)
96 tensor(9.8694e-09, grad_fn=<MseLossBackward0>)
97 tensor(8.5395e-09, grad_fn=<MseLossBackward0>)
98 tensor(7.1282e-09, grad_fn=<MseLossBackward0>)
99 tensor(5.6800e-09, grad_fn=<MseLossBackward0>)
w=  1.9999539852142334
b=  9.305930871050805e-05
y_pred=  tensor([[7.9999]])

Process finished with exit code 0

(此处原为一张图片,推测为上述代码绘制的"Cost in each epoch"损失曲线;原图未能保留。)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值