【PyTorch】深度学习实战之PyTorch实现线性回归

PyTorch实现线性回归

可调用对象:

如果要使用一个可调用对象,那么在类的声明的时候要定义一个call函数

class Foobar: 
    def __init__(self):
        pass
    def __call__(self,*args,**kwargs):
        pass

其中参数*args代表把前n个参数变成n元组,**kwargsd会把参数变成一个词典,这些都是python的基础语法

def func(*args,**kwargs):
    print(args)
    print(kwargs)
    
func(1,2,3,4,x=3,y=5)
"""
(1, 2, 3, 4)
{'x': 3, 'y': 5}
"""

PyTorch线性回归的四个过程:

  • 准备训练集
  • 使用类设计模型(目的是为了前向传播forward,计算y hat)
  • 构造损失函数和优化器(其中,loss是为了进行反向传播,optimizer是为了更新梯度
  • 循环训练(前向算损失,反向算梯度,然后不断更新)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HF6SQqKx-1647356737901)(../AppData/Roaming/Typora/typora-user-images/image-20220315105324246.png)]

每一次训练的过程就是:

  • 前向传播,求y_hat(预测值)
  • 根据y_hat和y_label(y_data)计算loss
  • 反向传播backward(计算梯度)
  • 根据梯度,更新参数

实现代码:

import torch
import matplotlib.pyplot as plt
import numpy as np

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):  # 构造函数
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # 构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b

    def forward(self, x):
        y_pred = self.linear(x)  # 可调用对象,计算y=wx+b
        return y_pred


model = LinearModel()  # 实例化模型

criterion = torch.nn.MSELoss(reduction='sum')
# model.parameters()会扫描module中的所有成员,如果成员中有相应的权重,那么都会将结果加到要训练的集合参数上
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr为学习率

epoch_list = []
loss_list = []
# for epoch in np.arange(0, 100, 2):
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_list.append(epoch)
    loss_list.append(loss.item())
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
plt.plot(epoch_list, loss_list)
plt.xlabel('times')
plt.ylabel('loss')
plt.title('SGD')
plt.show()

运行结果:

0 111.91926574707031
1 49.82788848876953
2 22.186500549316406
3 9.881272315979004
4 4.403273582458496
5 1.9645596742630005
6 0.8788504600524902
7 0.3954624831676483
8 0.18021120131015778
9 0.08432696759700775
10 0.04158348590135574
11 0.022497136145830154
12 0.013943195343017578
13 0.01007873099297285
14 0.008302716538310051
15 0.007457221858203411
16 0.007026821840554476
17 0.006781961768865585
18 0.006620422005653381
19 0.006496733520179987
20 0.006390667520463467
21 0.006293224636465311
22 0.0062002213671803474
23 0.006110009737312794
24 0.006021701730787754
25 0.005934945307672024
26 0.005849512759596109
27 0.0057654669508337975
28 0.005682558752596378
29 0.005600868724286556
30 0.005520401056855917
31 0.005441035609692335
32 0.0053628794848918915
33 0.005285775288939476
34 0.005209808703511953
35 0.005134933162480593
36 0.005061125382781029
37 0.004988380707800388
38 0.00491672195494175
39 0.0048460508696734905
40 0.004776409827172756
41 0.00470777926966548
42 0.004640108905732632
43 0.004573439247906208
44 0.00450771301984787
45 0.004442923702299595
46 0.004379057325422764
47 0.004316150210797787
48 0.004254107363522053
49 0.004192924126982689
50 0.0041326736100018024
51 0.004073282703757286
52 0.004014759790152311
53 0.003957051318138838
54 0.0039002075791358948
55 0.0038441140204668045
56 0.003788899164646864
57 0.0037344531156122684
58 0.003680775174871087
59 0.0036278674378991127
60 0.003575714770704508
61 0.0035243607126176357
62 0.003473697230219841
63 0.003423791378736496
64 0.003374570980668068
65 0.0033260590862482786
66 0.003278267802670598
67 0.003231176408007741
68 0.0031847076024860144
69 0.003138953121379018
70 0.003093830542638898
71 0.0030493782833218575
72 0.0030055600218474865
73 0.0029623594600707293
74 0.002919779857620597
75 0.0028778419364243746
76 0.002836476778611541
77 0.002795706270262599
78 0.0027555148117244244
79 0.0027159445453435183
80 0.002676892327144742
81 0.0026384363882243633
82 0.002600492676720023
83 0.0025631182361394167
84 0.002526274649426341
85 0.0024899819400161505
86 0.002454179571941495
87 0.002418922260403633
88 0.0023841557558625937
89 0.002349911257624626
90 0.002316119149327278
91 0.002282818779349327
92 0.0022500380873680115
93 0.002217694651335478
94 0.002185826888307929
95 0.00215441663749516
96 0.0021234452724456787
97 0.0020929216407239437
98 0.002062862040475011
99 0.0020332084968686104
100 0.002003985922783613
101 0.0019751866348087788
102 0.0019467804813757539
103
  • 5
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch深度回归模型是一种使用深度学习技术来实现回归分析的模型。在PyTorch中,可以使用多种深度学习技术来实现回归模型,包括神经网络、卷积神经网络、循环神经网络等。这些模型可以用于多种回归任务,如预测房价、股票价格、销售额等。 PyTorch深度回归模型的主要特点是: 1. 基于深度学习技术,能够处理大量数据和复杂模型。 2. 支持自定义网络结构,可以根据不同任务需求自由设计网络结构。 3. 支持GPU加速,能够快速处理大量数据。 4. 支持自动求导,能够自动计算梯度,简化模型训练过程。 在PyTorch中,可以使用torch.nn模块来构建深度回归模型。该模块提供了各种神经网络层和损失函数,可以用于构建不同类型的深度回归模型。下面是一个使用PyTorch实现的简单线性回归模型: ```python import torch # 构建数据集 x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]]) # 构建模型 class LinearRegression(torch.nn.Module): def __init__(self): super(LinearRegression, self).__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): out = self.linear(x) return out model = LinearRegression() # 定义损失函数和优化器 criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 训练模型 num_epochs = 1000 for epoch in range(num_epochs): inputs = x_train labels = y_train # 前向传播 outputs = model(inputs) loss = criterion(outputs, labels) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() if (epoch+1) % 100 == 0: print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item())) # 测试模型 model.eval() with torch.no_grad(): predicted = model(x_train) print('Predicted:', predicted) ``` 在上述代码中,我们首先定义了一个简单的线性回归模型,然后使用MSE损失函数和随机梯度下降优化器进行训练。训练完成后,我们对模型进行测试,并输出预测结果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值