pytorch中的模块简介

1. 模块类

pytorch可以通过继承模块类来自定义模型的实例化,对内部定义的模块进行实例化,再通过前向计算调佣子模块,从而完成深度学习模型的搭建。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, ...):  # 模块初始化,...为用户的输入参数
        super(Model, self).__init__()  # 继承父类的方法
        ...  # 根据传入的参数来定义子模块
        
    def forward(self, ...):
        # 定义前向计算的输入参数,一般是张量或者其他的参数
        ret = ...  # 根据传入的张量和子模块计算返回张量
        return ret

通过__init__方法初始化整个模型,再使用forward方法对该模块进行前向计算。在使用__init__方法的时候,可以在类内初始化子模块,然后在forward方法中调用这些初始化的子模块,最后输出张量。

2. 基于模块类的简单线性回归类

2.1 pytorch线性回归模型实例
import torch
import torch.nn as nn

class LinearModel(nn.Module):
    def __init__(self, ndim):
        super(LinearModel, self).__init__()
        self.ndim = ndim
        
        self.weight = nn.Parameter(torch.randn(ndim, 1))  # 定义权重
        self.bias = nn.Parameter(torch.randn(1))  # 定义偏置
        
    def forward(self, x):
        # y = Wx + b
        return x.mm(self.weight) + self.bias

为了构造线性变换,我们需要知道输入特征维度的大小、线性回归的权重(self.weight)和偏置(self.bias)。在forward方法中输入一个特征向量 x x x(大小为迷你批次大小x特征维度大小),做线性变换(使用mm方法做矩阵乘法线性变换),加偏置的值,最后输入一个预测值。nn.Parameter包装参数使之成为子模块(仅有参数构造的子模块),方便后续参数优化。

2.2 pytorch线性回归模型调用方法实例
lm = LinearModel(5)  # 模型实例化,特征数为5
x = torch.randn(4, 5)  # 随机输入,迷你批次为4
lm(x)  # 每个迷你批次的输出

在这里插入图片描述

3. 线性回归类的实例化和方法调用

3.1 使用named_parameters方法和parameters方法获取模型的参数。

这两个方法都是返回生成器,named_parameters得到的是该模型所有参数的名称和对应的张量值,而parameters方法返回该模型的所有参数对应的张量值。

3.2 使用train和eval方法进行模型训练和测试状态的转换

在模型训练中,有些子模块有两种状态,及训练状态和预测状态,pytorch的模型经常需要在两种状态中相互转换。调用train方法会把模块(所有子模块)转换到训练状态,调用eval方法会把模块(所有子模块)转换到预测状态。pytorch的模型在不同的状态下的预测准确率会有差异,在训练模型的时候需要转换为训练状态,在预测的时候需要转换为预测状态,否则最后的预测准确率可能会降低。

3.3 使用named_buffers方法和buffers方法获取张量的缓存

除了通过反向传播得到梯度来进行训练的参数外,还有一些参数并不参与梯度传播,但是会在训练中得到更新,这种参数称为缓存(Buffer),具体的例子包括批次归一化层的平均值(Mean)和方差(Variance)。在模块中调用register_buffer方法可以在模块中加入这种类型的张量,并使用named_buffers可以获得缓存的名字和缓存的张量的值组成的生成器,通过buffers可以获取缓存张量值组成的生成器。

3.4 使用named_children方法和children方法获取模型的子模块

有时候需要对模块的子模块进行迭代,这时就需要使用named_children方法和children方法来获取子模块名字、子模块的生成器,以及只有子模块的生成器。如果要获取模块中所有模块的信息,可以使用named_modules和modules来(递归)得到相关信息。

3.5 使用apply方法递归地对子模块进行函数应用

如果需要对pytorch所有的模块应用到一个函数,可以使用apply方法,通过传入一个函数或者匿名函数来递归地应用这些函数,传入的函数一模块作为参数,在函数内部对模块进行修改。

3.6 改变模块参数数据类型和存储的位置

在深度学习中可以改变模块的参数所在设备(CPU或者GPU)。若要改变参数的数据类型,可以通过to方法加上需要转变的目标数据类型来实现,float方法会转换所有的参数为单精度浮点数,half方法会转换所有的参数

4. pytorch模块方法调用实例

lm = LinearModel(5)
x = torch.randn(4, 5)  # 模型输入
print(lm(x))  # 模型获取对应的输出
print(lm.named_parameters())  # 获取模型参数(带名字)的生成器
print(list(lm.named_parameters()))  # 转换生成器为列表
print('=' * 50)
print(lm.parameters())  # 获取模型参数(不带名字)的生长器
print(list(lm.parameters()))  # 转换生成器为列表
lm.half()  # 转换模型参数为半精度浮点数
print(lm.parameters())  # 显示模型参数,可以看到已经转换为半精度浮点数

在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

饕餮&化骨龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值