Pytorch查看网络参数和网络内变量名称

主要目标

打印如下信息

  • 网络结构
  • 网络每一层变量的名字
  • 网络每一层变量的具体每一个参数

因为有的时候直接整个网络打印的话(1)是会很难找(2)当参数很多的时候,有的时候会缩略起来看不到那个参数

实验代码

import torch 
import torch.nn as nn
import numpy as np
import os

class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.net1 = nn.Linear(2,1)

    def forward(self, x):
        x = self.net1(x)
        return x


if __name__ == '__main__':
    net_test = NET()
    # 打印网络中所有类内变量的信息(按照先后顺序)
    print(net_test)
    # 打印网络中所有类内变量参数值
    print(net_test.state_dict())
    # 打印网络构成的参数字典中所有的网络键值,之后根据这个键值就可以去查看固定哪一层的参数,然后通过索引甚至可以看到具体这一层的第几个参数
    print(net_test.state_dict().keys())
    # 通过字典键值索引打印某一个键值下面的参数
    print(net_test.state_dict()["net1.bias"])
    print(net_test.state_dict()["net1.bias"].shape)

代码结果

image-20210817210805148

补充说明

说明1

这些方式会对所有的参数进行输出,即使没有在forward中出现也会在网络定义的时候初始化,获得内存具体可以看下面的例子,并且显示的有序字典是根据声明类内变量的顺序,而不是在forward里面运行的顺序

import torch 
import torch.nn as nn
import numpy as np
import os

class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.net1_no_use = nn.Linear(2,1)
        self.net1 = nn.Linear(2,1)

    def forward(self, x):
        x = self.net1(x)
        return x

if __name__ == '__main__':
    net_test = NET()
    print(net_test)
    print(net_test.state_dict())
    print(net_test.state_dict().keys())
    print(net_test.state_dict()["net1.bias"])
    print(net_test.state_dict()["net1.bias"].shape)

image-20210817211055543

说明2

如果想要打印所有的参数可以使用如下操作

import torch 
import torch.nn as nn
import numpy as np
import os

class NET(nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.net1_no_use = nn.Linear(2,1)
        self.net1 = nn.Linear(2,1)

    def forward(self, x):
        x = self.net1(x)
        return x

if __name__ == '__main__':
    net_test = NET()
    print(net_test)
    for name,parameters in net_test.named_parameters():
        print(name,':',parameters,parameters.size())

image-20210817211450171

LAST 参考文献

Pytorch 查看模型参数_happyday_d的博客-CSDN博客_pytorch查看模型参数

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值