详细分析Pytorch中的register_buffer基本知识(附Demo)

1. 基本知识

register_buffer 是 PyTorch 中 torch.nn.Module 提供的一个方法,允许用户将某些张量注册为模块的一部分,但不会被视为可训练参数。这些张量会随模型保存和加载,但在反向传播过程中不会更新

register_buffer 的作用

  • 将张量注册为模型的缓冲区(buffer),意味着这些张量会与模型一起保存和加载
  • 与参数不同,缓冲区不会参与梯度计算,因此不会在训练时更新
  • 常用于存储像均值、方差、掩码或其他状态信息

与模型参数的区别

  • 模型参数(register_parameterself.param = nn.Parameter(tensor)):这些张量会被认为是可学习的权重,在训练过程中会被优化器更新
  • 缓冲区(register_buffer):这些张量不会被优化器更新,适合用于保存模型的常量或中间状态

使用的场景有如下:

  1. 存储一些与训练无关但随模型保存的常量
  2. 存储一些在 eval() 模式下需要使用的统计数据(如 BatchNorm 层中的均值和方差)
  3. 缓存计算中间的状态或掩码,不希望它们在训练中被更新

2. Demo

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 注册一个缓冲区, 不参与训练
        self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))
    
    def forward(self, x):
        # 使用缓冲区中的张量
        return x + self.constant_tensor

model = MyModel()
print(model.constant_tensor)  # 输出缓冲区内容 tensor([1., 2., 3.])

在训练模式和评估模式中的差异

model.train()  # 训练模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

model.eval()  # 评估模式
print(model(torch.tensor([1.0, 1.0, 1.0])))  # tensor([2.0, 3.0, 4.0])

缓冲区内容不会因为模型模式的切换(train() 或 eval())而改变,因为缓冲区不是可训练参数,它仅存储数据

示例 1: 存储中间计算的掩码

class MaskedModel(nn.Module):
    def __init__(self):
        super(MaskedModel, self).__init__()
        self.register_buffer('mask', torch.tensor([1.0, 0.0, 1.0]))
    
    def forward(self, x):
        # 应用掩码到输入上
        return x * self.mask

model = MaskedModel()
input_tensor = torch.tensor([4.0, 3.0, 2.0])
output = model(input_tensor)
print(output)  # tensor([4.0, 0.0, 2.0])

示例 2: 保存 BatchNorm 层的均值和方差
在批归一化层(BatchNorm)中,均值和方差是通过 register_buffer 来存储的
这些值在训练时会动态更新,但在推理时会固定使用

class MyBatchNorm(nn.Module):
    def __init__(self, num_features):
        super(MyBatchNorm, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
    
    def forward(self, x):
        # 这里简单模拟了批归一化的逻辑
        return (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5)

bn_layer = MyBatchNorm(3)
input_tensor = torch.tensor([2.0, 3.0, 4.0])
output = bn_layer(input_tensor)
print(output)

示例 3: 用于固定的均值和方差
不希望在训练时更新某些统计数据(比如均值和方差),而希望使用固定的值

class FixedNorm(nn.Module):
    def __init__(self):
        super(FixedNorm, self).__init__()
        # 注册固定的均值和方差
        self.register_buffer('mean', torch.tensor([0.5]))
        self.register_buffer('std', torch.tensor([0.25]))
    
    def forward(self, x):
        return (x - self.mean) / self.std

model = FixedNorm()
input_tensor = torch.tensor([1.0, 0.5, 0.0])
output = model(input_tensor)
print(output)  # tensor([ 2.0, 0.0, -2.0])

3. 与自动注册的差异

补充与上述不同的知识点,上述为手动注册,下面为自动注册

3.1 torch.nn.Parameter

这些 Parameter 会参与反向传播和梯度更新。
通过 model.parameters() 可以获得所有自动注册的参数

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # torch.nn.Parameter 会自动注册为模型参数
        self.weight = nn.Parameter(torch.randn(3, 3))
    
    def forward(self, x):
        return x @ self.weight

model = MyModel()
# weight 自动注册为参数
print(list(model.parameters()))  # 输出: [Parameter containing: tensor(...)]

3.2 自动注册子模块

在模型的 __init__ 函数中定义了 torch.nn.Module 子模块(例如卷积层、线性层等),PyTorch 会自动将这些子模块注册到模型中,并且它们的参数也会一并注册

  • 通过 model.children() 或 model.named_children() 获取所有子模块
  • 子模块的所有参数也会自动注册,可以通过 model.parameters() 或 model.named_parameters() 获取
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # nn.Module 子模块会自动注册
        self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
    
    def forward(self, x):
        return self.conv(x)

model = MyModel()
# conv 层自动注册为模型的子模块
print(list(model.children()))  # 输出: [Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))]

3.3 总结

  • register_buffer:用于注册不需要梯度的张量,比如存储模型状态的变量,不参与反向传播和优化过程
  • register_parameter:用于注册模型的可训练参数,会参与梯度计算和优化过程

两者结合的Demo

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 自动注册的子模块
        self.fc = nn.Linear(5, 5)

        # 手动注册一个可训练的参数
        self.register_parameter('my_weight', nn.Parameter(torch.randn(5, 5)))

        # 手动注册一个不可训练的缓冲区
        self.register_buffer('my_buffer', torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))

    def forward(self, x):
        x = self.fc(x)
        return x @ self.my_weight + self.my_buffer


model = MyModel()
# 输出所有注册的参数和缓冲区
print("Parameters:", list(model.named_parameters()))
print("Buffers:", list(model.named_buffers()))

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农研究僧

你的鼓励将是我创作的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值