主要目标
打印如下信息
- 网络结构
- 网络每一层变量的名字
- 网络每一层变量的具体每一个参数
因为有的时候直接整个网络打印的话(1)是会很难找(2)当参数很多的时候,有的时候会缩略起来看不到那个参数
实验代码
import torch
import torch.nn as nn
import numpy as np
import os
class NET(nn.Module):
def __init__(self):
super(NET, self).__init__()
self.net1 = nn.Linear(2,1)
def forward(self, x):
x = self.net1(x)
return x
if __name__ == '__main__':
net_test = NET()
# 打印网络中所有类内变量的信息(按照先后顺序)
print(net_test)
# 打印网络中所有类内变量参数值
print(net_test.state_dict())
# 打印网络构成的参数字典中所有的网络键值,之后根据这个键值就可以去查看固定哪一层的参数,然后通过索引甚至可以看到具体这一层的第几个参数
print(net_test.state_dict().keys())
# 通过字典键值索引打印某一个键值下面的参数
print(net_test.state_dict()["net1.bias"])
print(net_test.state_dict()["net1.bias"].shape)
代码结果
补充说明
说明1
这些方式会对所有的参数进行输出,即使没有在forward中出现也会在网络定义的时候初始化,获得内存具体可以看下面的例子,并且显示的有序字典是根据声明类内变量的顺序,而不是在forward里面运行的顺序
import torch
import torch.nn as nn
import numpy as np
import os
class NET(nn.Module):
def __init__(self):
super(NET, self).__init__()
self.net1_no_use = nn.Linear(2,1)
self.net1 = nn.Linear(2,1)
def forward(self, x):
x = self.net1(x)
return x
if __name__ == '__main__':
net_test = NET()
print(net_test)
print(net_test.state_dict())
print(net_test.state_dict().keys())
print(net_test.state_dict()["net1.bias"])
print(net_test.state_dict()["net1.bias"].shape)
说明2
如果想要打印所有的参数可以使用如下操作
import torch
import torch.nn as nn
import numpy as np
import os
class NET(nn.Module):
def __init__(self):
super(NET, self).__init__()
self.net1_no_use = nn.Linear(2,1)
self.net1 = nn.Linear(2,1)
def forward(self, x):
x = self.net1(x)
return x
if __name__ == '__main__':
net_test = NET()
print(net_test)
for name,parameters in net_test.named_parameters():
print(name,':',parameters,parameters.size())