pytorch获取全部权重参数、每一层权重参数
首先需要安装torchsummary
在相应的虚拟环境下pip install torchsummary
1、打印每层参数信息:
summary(net,input_size,batch_size,device),
net:网络模型
input_size:网络输入图片的shape
batch_size:默认参数为-1
device:在gpu上还是cpu上运行,默认是cuda在gpu上运行,若想在cpu上运行,需将参数改为cpu。
eg.vgg16网络:
from models import VGG16_torch
model = vgg16()
summary(model,(3,32,32),device=‘cpu’)
2、根据需要,输出相应层的权重
首先查看每层对应的名称
model = vgg16()
for name in model.state_dict():
print(name)
再根据名称输出相应层的权重
print(model.state_dict()['layers.0.conv2d.weight'])
3、打印模块名字和参数大小
for name, parameters in model.named_parameters():
print(name, ';', parameters.size())
输出结果:
4、加载模型全部参数
import torch
y = torch.load('vgg16_baseline.t7')
print(y)