1、打印网络每层输出的大小
import torch
import torch.nn as nn
from torchsummary import summary
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 10)
def forward(self, x):
x1 = torch.relu(self.fc1(x))
x2 = self.fc2(x1)
return x1, x2
model = Net()
summary(model, input_size=(3, 20, 10))
# print(model)
2、输出每个参数矩阵的名称(nn.Module里面关于参数有两个很重要的属性named_parameters()和parameters(),前者给出网络层的名字和参数的迭代器,而后者仅仅是参数的迭代器。)
for param in model.named_parameters():
print(param[0]) # 输出参数名称
3、输出具体的网络参数值
paras = list(model.parameters())
for num,para in enumerate(paras):
print('number:',num)
print(para)
print('_____________________________')
<