1. 基本知识
register_buffer 是 PyTorch 中 torch.nn.Module 提供的一个方法,允许用户将某些张量注册为模块的一部分,但不会被视为可训练参数。这些张量会随模型保存和加载,但在反向传播过程中不会更新
register_buffer 的作用:
- 将张量注册为模型的缓冲区(buffer),意味着这些张量会与模型一起保存和加载
- 与参数不同,缓冲区不会参与梯度计算,因此不会在训练时更新
- 常用于存储像均值、方差、掩码或其他状态信息
与模型参数的区别:
- 模型参数(
register_parameter
或self.param = nn.Parameter(tensor)
):这些张量会被认为是可学习的权重,在训练过程中会被优化器更新 - 缓冲区(
register_buffer
):这些张量不会被优化器更新,适合用于保存模型的常量或中间状态
使用的场景有如下:
- 存储一些与训练无关但随模型保存的常量
- 存储一些在 eval() 模式下需要使用的统计数据(如 BatchNorm 层中的均值和方差)
- 缓存计算中间的状态或掩码,不希望它们在训练中被更新
2. Demo
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 注册一个缓冲区, 不参与训练
self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))
def forward(self, x):
# 使用缓冲区中的张量
return x + self.constant_tensor
model = MyModel()
print(model.constant_tensor) # 输出缓冲区内容 tensor([1., 2., 3.])
在训练模式和评估模式中的差异
model.train() # 训练模式
print(model(torch.tensor([1.0, 1.0, 1.0]))) # tensor([2.0, 3.0, 4.0])
model.eval() # 评估模式
print(model(torch.tensor([1.0, 1.0, 1.0]))) # tensor([2.0, 3.0, 4.0])
缓冲区内容不会因为模型模式的切换(train() 或 eval())而改变,因为缓冲区不是可训练参数,它仅存储数据
示例 1: 存储中间计算的掩码
class MaskedModel(nn.Module):
def __init__(self):
super(MaskedModel, self).__init__()
self.register_buffer('mask', torch.tensor([1.0, 0.0, 1.0]))
def forward(self, x):
# 应用掩码到输入上
return x * self.mask
model = MaskedModel()
input_tensor = torch.tensor([4.0, 3.0, 2.0])
output = model(input_tensor)
print(output) # tensor([4.0, 0.0, 2.0])
示例 2: 保存 BatchNorm 层的均值和方差
在批归一化层(BatchNorm)中,均值和方差是通过 register_buffer 来存储的
这些值在训练时会动态更新,但在推理时会固定使用
class MyBatchNorm(nn.Module):
def __init__(self, num_features):
super(MyBatchNorm, self).__init__()
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, x):
# 这里简单模拟了批归一化的逻辑
return (x - self.running_mean) / torch.sqrt(self.running_var + 1e-5)
bn_layer = MyBatchNorm(3)
input_tensor = torch.tensor([2.0, 3.0, 4.0])
output = bn_layer(input_tensor)
print(output)
示例 3: 用于固定的均值和方差
不希望在训练时更新某些统计数据(比如均值和方差),而希望使用固定的值
class FixedNorm(nn.Module):
def __init__(self):
super(FixedNorm, self).__init__()
# 注册固定的均值和方差
self.register_buffer('mean', torch.tensor([0.5]))
self.register_buffer('std', torch.tensor([0.25]))
def forward(self, x):
return (x - self.mean) / self.std
model = FixedNorm()
input_tensor = torch.tensor([1.0, 0.5, 0.0])
output = model(input_tensor)
print(output) # tensor([ 2.0, 0.0, -2.0])
3. 与自动注册的差异
补充与上述不同的知识点,上述为手动注册,下面为自动注册
3.1 torch.nn.Parameter
这些 Parameter 会参与反向传播和梯度更新。
通过 model.parameters() 可以获得所有自动注册的参数
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# torch.nn.Parameter 会自动注册为模型参数
self.weight = nn.Parameter(torch.randn(3, 3))
def forward(self, x):
return x @ self.weight
model = MyModel()
# weight 自动注册为参数
print(list(model.parameters())) # 输出: [Parameter containing: tensor(...)]
3.2 自动注册子模块
在模型的 __init__
函数中定义了 torch.nn.Module 子模块(例如卷积层、线性层等),PyTorch 会自动将这些子模块注册到模型中,并且它们的参数也会一并注册
- 通过 model.children() 或 model.named_children() 获取所有子模块
- 子模块的所有参数也会自动注册,可以通过 model.parameters() 或 model.named_parameters() 获取
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# nn.Module 子模块会自动注册
self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
def forward(self, x):
return self.conv(x)
model = MyModel()
# conv 层自动注册为模型的子模块
print(list(model.children())) # 输出: [Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))]
3.3 总结
register_buffer
:用于注册不需要梯度的张量,比如存储模型状态的变量,不参与反向传播和优化过程register_parameter
:用于注册模型的可训练参数,会参与梯度计算和优化过程
两者结合的Demo
import torch
from torch import nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 自动注册的子模块
self.fc = nn.Linear(5, 5)
# 手动注册一个可训练的参数
self.register_parameter('my_weight', nn.Parameter(torch.randn(5, 5)))
# 手动注册一个不可训练的缓冲区
self.register_buffer('my_buffer', torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))
def forward(self, x):
x = self.fc(x)
return x @ self.my_weight + self.my_buffer
model = MyModel()
# 输出所有注册的参数和缓冲区
print("Parameters:", list(model.named_parameters()))
print("Buffers:", list(model.named_buffers()))