Pytorch buffer(register_buffer)

      回顾模型保存:torch.save(model.state_dict()),其中model.state_dict()是一个字典,里边存着我们模型各个部分的参数。在model中,我们需要更新其中的参数,训练结束将参数保存下来。但在某些时候,我们可能希望模型中的某些参数不更新(从开始到结束均保持不变),但又希望参数保存下来(model.state_dict() ),例如BN的running_mean和running_var, 它们在训练过程中是不需要训练的,但是需要在整个训练过程中维护这个mean和var值。这时我们就会用到 register_buffer()

模型中需要保存下来的参数包括两种:

  • 一种是反向传播需要被optimizer更新的,称之为 parameter
  • 一种是反向传播不需要被optimizer更新,称之为 buffer

第一种参数我们可以通过 model.parameters() 返回;第二种参数我们可以通过 model.buffers() 返回。

model.state_dict()会同时包含parameter和buffer

但不一定每个网络都会有buffer,像nn.Linear()就没有

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.register_buffer('test',torch.rand(input_size, output_size))
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.linear(x)

model = MyModule(4, 2)
print(list(model.buffers()))
print(list(model.named_buffers()))

输出model.state_dict()会包含buffer的

import torch
from torch import nn
 
class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.register_buffer('test',torch.rand(input_size, output_size))
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.linear(x)
 
model = MyModule(4, 2)
print(model.state_dict())

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值