"""
@Date : 2022/11/10 11:25
@Author : ZZJin
@File : vgg_weight.py
@Location :
@Description : 提取VGG16网络中的权重参数ZQM
               (Extract the weight and bias parameters of the VGG16 network.)
"""
import numpy as np
import torch.nn as nn
import torchvision
# Load the ImageNet-pretrained VGG16. torchvision >= 0.13 deprecates
# `pretrained=True` in favour of the `weights=` enum (IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` maps to); fall back for older releases
# that do not provide the enum.
_vgg16_weights = getattr(torchvision.models, "VGG16_Weights", None)
if _vgg16_weights is not None:
    model = torchvision.models.vgg16(weights=_vgg16_weights.IMAGENET1K_V1)
else:
    model = torchvision.models.vgg16(pretrained=True)
# To inspect the model architecture, uncomment:
# print(model)


def _to_numpy(tensor):
    """Return an independent NumPy copy of *tensor* (detached, moved to the CPU)."""
    return tensor.detach().cpu().numpy().copy()


# Maps keys such as '1Conv2d_Weight', '1Conv2d_Bias', '2Conv2d_Weight', ...
# to NumPy arrays. Conv2d and Linear layers share one counter that starts at 1
# and follows model.modules() order, so the numbering is global, not per type.
feature_map = dict()
layer_num = 0
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        kind = 'Conv2d'
    elif isinstance(m, nn.Linear):
        kind = 'Linear'
    else:
        # modules() also yields containers and parameter-free layers; they
        # do not consume a layer number.
        continue
    layer_num = layer_num + 1
    prefix = str(layer_num) + kind
    feature_map[prefix + '_Weight'] = _to_numpy(m.weight)  # extract weights
    # A layer built with bias=False has m.bias set to None; skip it rather
    # than crash.
    if m.bias is not None:
        feature_map[prefix + '_Bias'] = _to_numpy(m.bias)  # extract biases

# np.save stores the dict pickled inside a 0-d object array. Read it back with
# np.load('feature_map.npy', allow_pickle=True).item().
np.save('feature_map.npy', feature_map)  # keep the .npy suffix explicit