pytorch的buffer学习整理

pytorch模型中的buffer

这段时间忙于做项目,但是在项目中一直在模型构建中遇到buffer数据,所以花点时间整理下模型中的parameter和buffer数据的区别💕

1.torch.nn.Module.named_buffers(prefix=‘‘, recurse=True)

贴上pytorch官网对其的说明:
在这里插入图片描述
官网翻译:

named_buffers(prefix='', recurse=True)
方法: named_buffers(prefix='', recurse=True)

    Returns an iterator over module buffers, yielding both the name of the buffer as well 
    as the buffer itself.
    返回一个迭代器,该迭代器能够遍历模块的缓冲buffer,并且迭代返回的结果是缓冲的名字和缓冲本身.
    Parameters  参数
            prefix (str) – prefix to prepend to all buffer names.
            prefix (字符串) – 添加到所有缓冲名字之前的前缀.
            recurse (bool)if True, then yields buffers of this module and all submodules. 
            Otherwise, yields only buffers that are direct members of this module.
            recurse (布尔类型) – 如果该参数是True,那么表示递归地迭代返回,即迭代返回该模块的缓冲以及
            该模块的所有子模块的缓冲. 默认为True
    Yields  迭代返回
        (string, torch.Tensor) – Tuple containing the name and buffer
        (字符串,torch.Tensor类型) - 包含缓冲名字和缓冲自身的元组
        
    Example:  例子:

    >>> for name, buf in self.named_buffers():
    >>>    if name in ['running_var']:
    >>>        print(buf.size())

总结,缓冲buffer必须要登记注册才会有效,如果仅仅将张量赋值给Module模块的属性,不会被自动转为缓冲buffer.因而也无法被state_dict()、buffers()、named_buffers()访问到.此外state_dict()可以遍历缓冲buffer和参数Parameter.
可以概括为,缓冲buffer和参数Parameter的区别是前者不需要训练优化,而后者需要训练优化.在创建方法上也有区别,前者必须要将一个张量使用方法register_buffer()来登记注册,后者比较灵活,可以直接赋值给模块的属性,也可以使用方法register_parameter()来登记注册.
下面使用代码测试一下buffer数据:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])
                torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28])
        )
        self.attribute_buffer_in = torch.randn(3,5)                       # 仅仅赋值给模型属性,是无法访问到该buffer数据
        register_buffer_in_temp = torch.randn(4,6)               
        self.register_buffer('register_buffer_in', register_buffer_in_temp)   # 注册buffer数据,才能生效,能获取到数据

    def forward(self,x): 
        pass

print('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)
model = Model() #.cuda()



print('初始化之后模型修改之前'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))   
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))
for name, param in model.named_parameters():     # 访问模型的parameter参数数据的名字和其本身
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))           # 访问模型中的buffer数据本身
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))        # 访问模型中的parameter数据本身
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))        # 同时获取模型的parameter参数数据、buffer参数数据
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)



model.attribute_buffer_out = torch.randn(10,10)      # 赋值给模型属性
register_buffer_out_temp = torch.randn(15,15)
model.register_buffer('register_buffer_out', register_buffer_out_temp)  # 通过注册的方式,使得模型的buffer成员属性生效
print('模型初始化以及修改之后'.center(100,"-"))
print('调用named_buffers()'.center(100,"-"))         # 修改模型buffer属性之后,访问buffer数据名字和其本身
for name, buf in model.named_buffers():
    print(name,'-->',buf.shape)

print('调用named_parameters()'.center(100,"-"))      # 修改模型buffer属性之后,访问模型parameter数据名字和其本身
for name, param in model.named_parameters():
    print(name,'-->',param.shape)

print('调用buffers()'.center(100,"-"))
for buf in model.buffers():
    print(buf.shape)

print('调用parameters()'.center(100,"-"))
for param in model.parameters():
    print(param.shape)

print('调用state_dict()'.center(100,"-"))
for k, v in model.state_dict().items():
    print(k, '-->', v.shape)  

输出结果为:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 840 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63490' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
--------------------------------------------初始化之后模型修改之前---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])                     # 
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------模型初始化以及修改之后---------------------------------------------
-----------------------------------------调用named_buffers()------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
----------------------------------------调用named_parameters()----------------------------------------
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
--------------------------------------------调用buffers()---------------------------------------------
torch.Size([4, 6])
torch.Size([15, 15])
-------------------------------------------调用parameters()-------------------------------------------
torch.Size([64, 1, 3, 3])
torch.Size([64])
-------------------------------------------调用state_dict()-------------------------------------------
register_buffer_in --> torch.Size([4, 6])
register_buffer_out --> torch.Size([15, 15])
conv1.0.weight --> torch.Size([64, 1, 3, 3])
conv1.0.bias --> torch.Size([64])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 

模型中的buffer和parameter区别

在这里插入图片描述
在这里插入图片描述
下面使用代码进行说明:
pytorch保存模型参数的一种方式为:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

可以看到模型保存的是 model.state_dict() 的返回对象。 model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

class MyModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModule, self).__init__()
        self.lin = nn.Linear(input_size, output_size)
    def forward(self, x):
        return self.lin(x)

module = MyModule(4, 2)
print(module.state_dict())

输出结果:
在这里插入图片描述
分析一个parameter和buffer的例子:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)
        self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass
model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
for buffer in model.buffers():
    print(buffer)
print("----------------")
print(model.state_dict())

输出结果:
在这里插入图片描述

在这里插入图片描述

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
pytorch学习是指使用pytorch框架中的图神经网络(Graph Neural Networks,GNN)进行图数据的学习和分析。图数据是一种表示对象之间关系的数据结构,常用于社交网络、推荐系统、化学分子结构等领域。 在pytorch学习中,我们可以通过构建具有图结构的神经网络,来对图数据进行特征提取、分类、回归等任务。Pytorch提供了一些图神经网络的库,如DGL(Deep Graph Library)和PyG(PyTorch Geometric)等,可以方便地构建和训练图神经网络。 图神经网络的核心思想是将节点和边的特征进行学习,然后利用这些学习到的特征进行下游任务。图神经网络通常由多个图卷积层(Graph Convolutional Layer)组成,每个图卷积层都会更新节点的特征表示。通过多层的图卷积层堆叠,可以逐渐扩展节点的感受野,提取更全局的特征。 在pytorch学习中,除了图卷积层,还可以使用其他类型的图神经网络层,如图注意力层(Graph Attention Layer)和图池化层(Graph Pooling Layer)等,以提升网络的性能。同时,也可以结合传统的神经网络层,如全连接层和卷积层等,来处理节点和边的特征。 在实践中,pytorch学习常用于各种图数据的任务,如节点分类、链接预测和图生成等。通过对图结构的学习,可以提取出节点和边的有用特征,从而更好地理解和处理图数据。 总而言之,pytorch学习是利用pytorch框架进行图数据学习的方法,在处理图数据时能够更好地利用图结构中的信息,并应用于各种相关任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值