Pytorch中nn.Module中的self.register_buffer解释

self.register_buffer作用解释

今天遇到了这样一种用法,self.register_buffer(‘name’,Tensor),该方法的作用在于定义一组参数。该组参数在模型训练时不会更新(即调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值),但是该组参数又作为模型参数不可或缺的一部分。

实验

四种方式初始化模型中的参数

  1. 定义常见模型时的操作
  2. 使用register_buffer()定义一组参数
  3. 使用register_parameter()定义一组参数
  4. 使用python类的属性方式定义一组变量
import torch
import torch.nn as nn
from collections import OrderedDict

class Model(nn.Module):
	def __init__(self):
	super(Model,self).__init__()
	
	#(1)定义常见模型时的操作
	self.param_nn = nn.Sequential(OrderedDict([
		('conv',nn.Conv2d(1,1,3,bias=False)),
		('fc',nn.Linear(1,2,bias=False))
	]))
	
	#(2)使用register_buffer()定义一组参数
	self.register_buffer('reg_buf',torch.randn(1,2))

	#(3)使用register_parameter()定义一组参数
	self.register_parameter('reg_param',nn.Parameter(torch.randn(1,2)))

	#(4)使用python类的属性方式定义一组变量
	self.param_attr = torch.randn(1,2)

net = Model()
	

问题1:哪些参数会在模型训练时被更新?

因为定义优化器时会传入一个参数net.parameters,所以在模型训练时更新的参数可以通过list(net.named_parameters())查看
在这里插入图片描述
结果说明,只有方式(1)和方式(3)定义的参数可以被更新

问题2:模型中的参数到底有哪些?

模型中的所有参数都装在state_dict()中,所以可以通过net.state_dict()方式查看
在这里插入图片描述
结果说明,只有方式(4)的参数不在模型的参数列表,没有被模型训练时更新的参数reg_buf,依然在模型的参数列表里

self.register_buffer()的使用方法

  1. 传入参数:第一个参数传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数
  2. 在模型定义中调用:使用self.name方法,本例中就是self. reg_buf
  3. 在实例化模型后调用:使用net.buffers()方法。

其他知识

实际上,Pytorch定义的模型用OrderedDict()方式记录这三种类型,分别保存在self._modules, self._parameters 和self.buffer三个私有属性中

在模型实例化后可以用以下方法看三个私有属性中的变量
net.modules()
net.parameters()
net.buffers()

self._parameters 和net.parameters() 的返回值并不相同,self._parameters只记录了使用self.register_parameter()定义的参数,而net.parameters()返回所有可学习参数。

参考:
[1]Pytorchnn.Module中的self.register_buffer()解析

  • 27
    点赞
  • 63
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值