深度学习基础知识 register_buffer 与 register_parameter用法分析

深度学习基础知识 register_buffer 与 register_parameter用法分析

1、问题引入

思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数

import torch
import torch.nn as nn

# 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1=nn.Conv2d(in_channels= 3,
                            out_channels= 6,
                            kernel_size=3,
                            stride = 1,
                            padding=1,
                            bias=False)
        
        self.conv2=nn.Conv2d(in_channels= 6,
                            out_channels= 9,
                            kernel_size=3,
                            stride = 1,
                            padding=1,
                            bias=False)
        

        self.waight=torch.ones(10,10)
        self.bias=torch.zeros(10)

    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x = x * self.weight + self.bias
        return x
    
net=MyModule()

for name,param in net.named_parameters():  # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数
    print(name,param.shape)

print("\n","-"*40,"\n")

for key,val in net.state_dict().items():  # 说明weight与bias是不会被state_dict转化为字典中的元素的
    print(key,val.shape)

打印分析结果:
在这里插入图片描述
可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数

2、register_parameter()

register_parameter()是 torch.nn.Module 类中的一个方法

2.1 作用

1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

2.2 用法

register_parameter(name,param)

  • name:参数名称
  • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,

否则报错如下
在这里插入图片描述

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)

结果显示:
在这里插入图片描述

3、register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

  • self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

  • 参数:可以被优化器更新 (requires_grad=False / True)
  • buffer 中的数据 : 不会被优化器更新

3.2 用法

register_buffer(name,tensor)

  • name:参数名称
  • tensor:张量

代码:

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_buffer('weight', torch.ones(10, 10))   # 注意:定义的方式
        self.register_buffer('bias', torch.zeros(10))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)

效果如下所示:
在这里插入图片描述

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
这是一个实现了注意力机制的神经网络模块,主要用于处理输入序列中不同位置之间的关系。其中,dim代表输入特征的维度,num_ttokens表示输入序列的长度,num_heads表示注意力头数,qkv_bias表示是否对注意力中的查询、键、值进行偏置,qk_scale表示缩放因子,attn_drop表示注意力中的dropout率,proj_drop表示输出结果的dropout率,with_qkv表示是否需要对输入进行线性变换。 在实现中,首先根据输入的维度和头数计算每个头的维度head_dim,然后根据缩放因子scale对查询、键、值进行线性变换,得到每个头的查询、键、值向量。如果with_qkv为True,则需要对输入进行线性变换得到查询、键、值向量;否则直接使用输入作为查询、键、值向量。 接着,计算注意力分数,即将查询向量和键向量点乘并除以缩放因子scale,然后通过softmax函数得到注意力权重。将注意力权重与值向量相乘并进行加权平均,得到最终的输出结果。 另外,为了考虑不同位置之间的关系,在实现中还引入了相对位置编码。具体来说,通过计算每个位置之间的相对距离,得到一个相对位置编码矩阵,然后将其转化为一个参数relative_position_bias_table,并通过注册buffer的方式保存在模块中。在计算注意力分数时,将查询向量和键向量的相对位置编码相加,从而考虑不同位置之间的相对关系。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值