2021-07-14

神经网络训练权重的提取

前言

基于pytorch搭建神经网络后,通过训练过程得到权重。权重数据有时需要保存在嵌入式设备的bram中,本文主要介绍将通过软件计算得到的权重以及偏置数据进行提取,以及将它们整理成合适的形式放入文本文件中。之后利用软件控制程序(例如SDK等)在硬件初始化过程中将权重等数据输入到bram中,用于硬件的识别功能的实现:硬件通过自带的外设(如麦克风,摄像设备等)获得外界输入,之后经过专用硬件电路实现输入的预处理,之后将预处理得到的数据输入到相应预测的硬件电路中,利用软件训练出的权重及偏置实现预测。 因此,本文主要介绍从pycharm中提取寻训练好的权重及偏置,并进行整理。

一、权重提取——利用Module named_parameters

1. Module named_parameters介绍

named_parameters 能够获取到所有的参数。类中的成员是私有的,通过这种方式能够获取到所有的参数,可以获得权重或偏置的命名以及具体的数据值。

相应的代码简要表示如下:

class Net(nn.Module):

    def __init__(self, device):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(96, 96)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 12)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

for name, param in Net(device).named_parameters():
        print(f"name: {name}, param: {param}")

之后就可以打印出在Net类中定义的所有网络层中的权重以及偏置的名字与权重。

对于具体的打印形式由print定义的方式决定。

可以参考:
链接: https://www.jianshu.com/p/bb88f7c08022.

2. 对named_parameters的巧用(获取某一层权重)

如果想要获取某一层的权重,可以直接使用named_parameters得到某一层的权重,如:

resnet18 = models.resnet(pretrained=True)
for name,parameters in resnet.named_parameters():
    print(name,':',parameters.size())

另一种方法是,如果想要获取的权重或偏置的层数过多而且并不连续,如果对每层使用named_parameters进行提取的话,可能较为繁琐,可以考虑的方法是:

  1. 直接获取整个models的所有权重及偏置
  2. 将权重以及偏置以字典的形式进行存储,键值对中,键对应权重名称,值对应numpy数组的形式

将名称对应的权重值转换成数组的形式,可以方便对权重或偏置数据的拆分,对于嵌入式设备来说,可能需要对软件计算出的权重进行拆分,然后进行分类输入到嵌入式设备中。如对于RNN中的GRU,得到的权重是ih_weight_l[k],hh_weight_l[k]等,其分别可以拆分为w_ir|w_iz|w_inw_hr|w_hz|w_hn

实现上述操作的代码为:

param={} # 初始化字典
for name,parameters in models.named_parameters():
    print(name,':',parameters.size())   #输出权重及偏置的尺寸
    param[name]=parameters.detach().numpy() #将张量转换成numpy数组的形式

以写的深度可分离卷积网络的权重及偏置为例
输出结果为:

fc1.weight : torch.Size([96, 96])
fc1.bias : torch.Size([96])
fc2.weight : torch.Size([64, 64])
fc2.bias : torch.Size([64])
fc3.weight : torch.Size([12, 64])
fc3.bias : torch.Size([12])
depth_conv.weight : torch.Size([1, 1, 1, 3])
point_conv.weight : torch.Size([12, 1, 1, 1])
point_conv.bias : torch.Size([12])

3. 利用得到的字典输出想要的数组

在经过2.的步骤得到相应的字典后,利用权重的名称得到对应的数据的数组

例如想要得到depth_conv的权重:

    # depth_conv.weight : torch.Size([1, 1, 1, 3])
    dw = []   #初始化dw数组,用于存放权重数据
    dw = parm['depth_conv.weight'][0,0,0,:]
    dw = dw.reshape(-1,1)
    print(dw)

这里进行reshape变换矩阵尺寸的目的是获得方便之后读入嵌入式设备中

打印结果为:

[[ 0.08452398]
 [ 0.47839177]
 [-0.20986146]]

4. 对于打印设置的特别说明

对于打印的设置,非常重要,设置合适的打印格式后,对于数据的整理事半功倍。
对于打印的设置的方法是:

torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None)

相关参数的解释:

  • precision: 显示浮点张量元素的精度,指的是小数点后的小数位数,默认是4位
  • threshold: 设置特定张量中显示的数据的个数,如果数据量超过该设定值,就会发生折叠,只显示前几行(列)和后几行(列),默认值是1000
  • edgeitems: 针对数据省略的情况,设置省略时前后显示的行数,默认为3

例如:

[[ 0.01917863]
 [ 0.09933621]
 [ 0.08588736]
 ...
 [-0.04827718]
 [ 0.05256374]
 [ 0.12050059]]
[[ 0.08901154]
 [ 0.04171528]
 [-0.00469625]
 ...
 [-0.05011886]
 [ 0.06542964]
 [-0.06165306]]
  • linewidth: 显示每行可以打印的最大字符数,注意事项
    1. 这里指的是字符个数,而不是数据个数,是包括"[]",字母等等在内的字符个数,但一定会保证显示完整的数据,不会发生一个数据在一行显示一部分,在下一行显示另一部分的情况
    1. 这里说的是最大字符数,是指若当前行字符超过设置的值时,会自动添加换行符,到下一行继续打印。不是说该行必须显示设置的字符数。这个设置有极大的方便,因为对于特定的权重矩阵的格式是由网络参数决定的,如果按照linewidth的默认值来看,打印结果会打乱原有的矩阵尺寸,对矩阵的拆分以及观察造成干扰,而将linewidth设置为极大值可以保证每行的打印按照原有的矩阵尺寸格式来打印
  • profile: 简便设置显示选项,分别有default,short,full满足显示要求
  • sci_mode: 显示是否使用科学计数法,设置位False即为不适用科学计数法

set_printoptions的具体使用示例及其他参数说明可以参考:
链接: https://blog.csdn.net/Fluid_ray/article/details/109556867.
链接: https://blog.csdn.net/Corollary/article/details/105920322.

总结

本节主要介绍了对于神经网络训练出的权重的提取工作,讲述了得到各层数据的方法,打印设置的小技巧。

参考素材:

  1. 链接: https://www.jianshu.com/p/bb88f7c08022.
  2. 链接: https://blog.csdn.net/Fluid_ray/article/details/109556867.
  3. 链接: https://blog.csdn.net/Corollary/article/details/105920322.
  4. 链接: https://blog.csdn.net/happyday_d/article/details/88974361.
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值