torch.nn.Module.buffers(recurse=True)

This post explains the concept and usage of buffers in PyTorch, including the difference between buffers and parameters, what buffers are for, and how to register them with register_buffer. A code example shows how buffers behave in a model.

Reference: torch.nn.Module.buffers(recurse=True)
Reference: Pytorch模型中的parameter与buffer
Reference: What is the difference between register_buffer and register_parameter of nn.Module
Reference: Pytorch模型中的parameter与buffer(torch.nn.Module的成员)


Original documentation (translated):

buffers(recurse=True)
    Returns an iterator over module buffers.
    Parameters
        recurse (bool) – if True, then yields buffers of this module and all submodules.
        Otherwise, yields only buffers that are direct members of this module.
    Yields
        torch.Tensor – module buffer

    Example:
    >>> for buf in model.buffers():
    >>>     print(type(buf.data), buf.size())
    <class 'torch.FloatTensor'> (20L,)
    <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
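
To make the recurse flag concrete, here is a minimal sketch with made-up Outer/Inner modules; named_buffers() is used so the owning module is visible in each name, and buffers() honors recurse in the same way:

import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super(Inner, self).__init__()
        self.register_buffer('inner_buf', torch.zeros(2))   # buffer owned by the submodule

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.inner = Inner()
        self.register_buffer('outer_buf', torch.ones(3))    # buffer owned directly by Outer

m = Outer()
# recurse=True (the default) walks this module and all submodules.
print([name for name, _ in m.named_buffers()])               # ['outer_buf', 'inner.inner_buf']
# recurse=False yields only buffers that are direct members of this module.
print([name for name, _ in m.named_buffers(recurse=False)])  # ['outer_buf']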

To summarize: a buffer only exists if it is registered. Merely assigning a tensor to an attribute of a Module does not turn it into a buffer, so such a tensor cannot be reached through state_dict(), buffers(), or named_buffers(). Note that state_dict() iterates over both buffers and Parameters.
The difference between a buffer and a Parameter is that a buffer is not trained (it is not returned by parameters(), so the optimizer never updates it), whereas a Parameter is. They are also created differently: a buffer must be registered by passing a tensor to register_buffer(), while a Parameter is more flexible and can either be assigned directly to a module attribute as an nn.Parameter or registered with register_parameter().
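
As a quick illustration of these registration rules, here is a minimal sketch (the Demo class and its attribute names are made up for this example):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        # A Parameter can be created in two equivalent ways:
        self.weight_a = nn.Parameter(torch.randn(3))                        # direct attribute assignment
        self.register_parameter('weight_b', nn.Parameter(torch.randn(3)))   # explicit registration
        # A buffer must go through register_buffer():
        self.register_buffer('running_stat', torch.zeros(3))
        # A plain tensor attribute is neither a Parameter nor a buffer:
        self.plain_tensor = torch.randn(3)

demo = Demo()
print(sorted(name for name, _ in demo.named_parameters()))  # ['weight_a', 'weight_b']
print(sorted(name for name, _ in demo.named_buffers()))     # ['running_stat']
print(sorted(demo.state_dict().keys()))                     # ['running_stat', 'weight_a', 'weight_b']

Built-in layers such as nn.BatchNorm2d follow exactly this pattern: weight and bias are Parameters updated by the optimizer, while running_mean and running_var are buffers that are updated inside forward() and saved in the state_dict.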

Code experiment:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1 = torch.nn.Sequential(  # input: torch.Size([64, 1, 28, 28])
                torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(),  # output: torch.Size([64, 64, 28, 28])
        )
        # Plain tensor attribute: not registered, so invisible to buffers()/named_buffers()/state_dict().
        self.attribute_buffer_in = torch.randn(3, 5)
        # Registered buffer: shows up in buffers()/named_buffers()/state_dict().
        register_buffer_in_temp = torch.randn(4, 6)
        self.register_buffer('register_buffer_in', register_buffer_in_temp)

    def forward(self,x): 
        pass

print('CUDA (GPU) available:', torch.cuda.is_available())
print('torch version:', torch.__version__)
model = Model() #.cuda()



print('After initialization, before modification'.center(100, "-"))
print('Calling named_buffers()'.center(100, "-"))
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('Calling named_parameters()'.center(100, "-"))
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('Calling buffers()'.center(100, "-"))
for buf in model.buffers():
    print(buf.shape)

print('Calling parameters()'.center(100, "-"))
for param in model.parameters():
    print(param.shape)

print('Calling state_dict()'.center(100, "-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)



# After construction, add another plain attribute and another registered buffer.
model.attribute_buffer_out = torch.randn(10, 10)
register_buffer_out_temp = torch.randn(15, 15)
model.register_buffer('register_buffer_out', register_buffer_out_temp)
print('After model initialization and modification'.center(100, "-"))
print('Calling named_buffers()'.center(100, "-"))
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('Calling named_parameters()'.center(100, "-"))
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('Calling buffers()'.center(100, "-"))
for buf in model.buffers():
    print(buf.shape)

print('Calling parameters()'.center(100, "-"))
for param in model.parameters():
    print(param.shape)

print('Calling state_dict()'.center(100, "-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)

Console output:

CUDA (GPU) available: True
torch version: 1.2.0+cu92
-----------------------------After initialization, before modification------------------------------
--------------------------------------Calling named_buffers()---------------------------------------
register_buffer_in --> torch.Size([4, 6])
-------------------------------------Calling named_parameters()-------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
-----------------------------------------Calling buffers()------------------------------------------
torch.Size([4, 6])
----------------------------------------Calling parameters()----------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
----------------------------------------Calling state_dict()----------------------------------------
register_buffer_in --> torch.Size([4, 6])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
----------------------------After model initialization and modification-----------------------------
--------------------------------------Calling named_buffers()---------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
-------------------------------------Calling named_parameters()-------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
-----------------------------------------Calling buffers()------------------------------------------
torch.Size([4, 6])
torch.Size([15, 15])
----------------------------------------Calling parameters()----------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
----------------------------------------Calling state_dict()----------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
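
Notice that attribute_buffer_in and attribute_buffer_out never appear in any of the listings above: plain tensor attributes are not tracked by the module. A practical consequence is that registered buffers follow the module through device moves and through state_dict() save/load, while plain attributes do not. Below is a minimal sketch of that check, reusing the Model class from the experiment (the file name is arbitrary, and the device move only runs on a CUDA machine):

model = Model()

# Registered buffers are part of the module state, so .cuda()/.to() moves them;
# the plain attribute tensor stays where it was.
if torch.cuda.is_available():
    model.cuda()
    print(model.register_buffer_in.device)   # cuda:0
    print(model.attribute_buffer_in.device)  # cpu

# Registered buffers are also saved and restored through state_dict().
torch.save(model.state_dict(), 'model_state.pth')
restored = Model()
restored.load_state_dict(torch.load('model_state.pth'))
print(torch.equal(restored.register_buffer_in.cpu(), model.register_buffer_in.cpu()))  # True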