2.4 Pytorch基础模型组件及线性回归

本文介绍了Pytorch中模型构建的基础组件,包括nn.Module、nn.Sequential、优化器类和损失函数的使用。通过实例展示了如何用Pytorch实现线性回归模型,并讲解了如何在GPU上运行代码。此外,还详细介绍了多种优化算法,如梯度下降、SGD、MBGD、动量法、AdaGrad、RMSProp和Adam。最后探讨了深度学习中防止过拟合的策略,如Dropout和Batch Normalization。
摘要由CSDN通过智能技术生成

Pytorch基础模型组件

目标

  1. 知道Pytorch中Module的使用方法

  2. 知道Pytorch中优化器类的使用方法

  3. 知道Pytorch中常见的损失函数的使用方法

  4. 知道如何在GPU上运行代码

  5. 能够说出常见的优化器及其原理

1. Pytorch完成模型常用API

在前一部分,我们自己实现了通过torch的相关方法完成反向传播和参数更新,在pytorch中预设了一些更加灵活简单的对象,让我们来构造模型、定义损失,优化损失等

那么接下来,我们一起来了解一下其中常用的API

1.1 nn.Module

nn.Module 是torch.nn提供的一个类,是pytorch中我们自定义网络的一个基类,在这个类中定义了很多有用的方法,让我们在继承这个类定义网络的时候非常简单

当我们自定义网络的时候,有两个方法需要特别注意:

  1. __init__需要调用super方法,继承父类的属性和方法

  2. forward方法必须实现,用来定义我们的网络的向前计算的过程

用前面的y = wx+b的模型举例如下:

from torch import nn

class Lr(nn.Module):

    def __init__(self):

        super(Lr, self).__init__() #继承父类init的参数

        self.linear = nn.Linear(1, 1) # 声明网络中的组件



    def forward(self, x):

        out = self.linear(x)

        return out

注意:

  1. nn.Linear为torch预定义好的线性模型,也被称为全链接层,传入的参数为输入的数量,输出的数量(in_features, out_features),是不算(batch_size的列数)

  2. nn.Module定义了__call__方法,实现的就是调用forward方法,即Lr的实例,能够直接被传入参数调用,实际上调用的是forward方法并传入参数

# 实例化模型

model = Lr()

# 传入数据,计算结果

predict = model(x)

1.2 nn.Sequential

如果模型结构比较简单,在forward函数中没有很复杂的操作。这时可以用nn.Sequential来构建模型,nn.Sequential会自动完成forward函数的创建.

In [163]: model = nn.Sequential(nn.Linear(2,64), nn.Linear(64, 1))



In [164]: x = torch.randn(10,2) # 10个样本,2个特征



In [165]: model(x)

Out[165]:

tensor([[-0.3507],

[-0.3708],

[-0.4118],

[-0.2604],

[-0.4318],

[-0.3503],

[-0.4953],

[-0.5464],

[-0.5273],

[-0.4542]], grad_fn=<AddmmBackward>)

1.3 优化器类

优化器(optimizer),可以理解为torch为我们封装的用来进行更新参数的方法,比如常见的随机梯度下降(stochastic gradient descent,SGD)

优化器类都是由torch.optim提供的,例如

  1. torch.optim.SGD(参数,学习率)

  2. torch.optim.Adam(参数,学习率)

注意:

  1. 参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值