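1--Preface
While reproducing a paper recently, I needed to use a pretrained model, but the keys in the pretrained weight file did not match the keys of my custom model, so the pretrained weights could not be loaded.
2--Problem description
The keys of the pretrained weight file can be listed with the following script: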
import torch
if __name__ == "__main__":
    weights_files = './joint_model_stgcn.pt'  # path to the weight file
    weights = torch.load(weights_files)  # load the weight file
    for k, v in weights.items():  # key, value
        print(k)  # print each key (parameter name)
The keys in the original weight file are as follows:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9
fcn.weight
fcn.bias
The goal is to rename the following keys so that they match the custom model:
edge_importance.0 -> edge_importance0
edge_importance.1 -> edge_importance1
edge_importance.2 -> edge_importance2
edge_importance.3 -> edge_importance3
edge_importance.4 -> edge_importance4
edge_importance.5 -> edge_importance5
edge_importance.6 -> edge_importance6
edge_importance.7 -> edge_importance7
edge_importance.8 -> edge_importance8
edge_importance.9 -> edge_importance9
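One way to find out which keys need renaming is to compare the key set of the weight file with the key set the model expects. The sketch below is a minimal illustration on plain lists of names; the helper `diff_keys` is hypothetical and not part of the original script, and the sample keys are a small subset of the listing above.

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (unexpected, missing): keys only in the checkpoint, keys only in the model."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    return unexpected, missing

# Tiny illustration with a few key names taken from the listing above.
checkpoint_keys = ["A", "edge_importance.0", "edge_importance.1", "fcn.weight"]
model_keys = ["A", "edge_importance0", "edge_importance1", "fcn.weight"]
unexpected, missing = diff_keys(checkpoint_keys, model_keys)
print(unexpected)  # ['edge_importance.0', 'edge_importance.1']
print(missing)     # ['edge_importance0', 'edge_importance1']
```

With a real model, the same helper could presumably be called as `diff_keys(weights.keys(), model.state_dict().keys())`, where `model` stands for the custom model, which this post does not show.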
3--Code
Starting from the original weight file, build a new weight dictionary with collections.OrderedDict():
import torch
import collections
if __name__ == "__main__":
    # Load the original weight file
    weights_files = './joint_model_stgcn.pt'
    weights = torch.load(weights_files)
    # Rename: rebuild the ordered dict once per index, which keeps the key order unchanged
    new_d = weights
    for i in range(10):
        new_d = collections.OrderedDict([('edge_importance'+str(i), v) if k == 'edge_importance.'+str(i) else (k, v) for k, v in new_d.items()])
    # Check: print the renamed keys
    for k, v in new_d.items():  # key, value
        print(k)  # print the parameter name
    # Save the new weight file
    torch.save(new_d, 'new_joint_model_stgcn.pt')
The keys after the modification:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9
fcn.weight
fcn.bias
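The script above rebuilds the whole dictionary once for each of the ten indices. A single-pass variant is sketched below as an alternative, not as the method used in this post: it applies a regular expression to every key and leaves all other keys and all values untouched. The function name `rename_edge_importance` and the demo dictionary are made up for illustration; the integers stand in for tensors.

```python
import re
from collections import OrderedDict

# Matches exactly 'edge_importance.<digits>' and captures the digits.
_EDGE_KEY = re.compile(r"^edge_importance\.(\d+)$")

def rename_edge_importance(weights):
    """Return a new OrderedDict with 'edge_importance.N' renamed to 'edge_importanceN'."""
    return OrderedDict(
        (_EDGE_KEY.sub(r"edge_importance\1", key), value)
        for key, value in weights.items()
    )

# Placeholder values; in practice these would be the loaded tensors.
demo = OrderedDict([("A", 0), ("edge_importance.0", 1), ("edge_importance.9", 2), ("fcn.weight", 3)])
print(list(rename_edge_importance(demo)))
# ['A', 'edge_importance0', 'edge_importance9', 'fcn.weight']
```

Because the pattern is anchored at both ends, keys that merely contain the substring `edge_importance` are not changed, and the original key order is preserved.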
4--Test
Check whether the values in the original and the new weight file are identical:
import torch
if __name__ == "__main__":
    origin_weights_files = './joint_model_stgcn.pt'
    origin_weights = torch.load(origin_weights_files)
    new_weights_files = './new_joint_model_stgcn.pt'
    new_weights = torch.load(new_weights_files)
    # == compares element-wise, so each line prints a boolean tensor
    print(origin_weights['A'] == new_weights['A'])
    print(origin_weights['edge_importance.0'] == new_weights['edge_importance0'])
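The check above inspects only two entries. A sketch that walks over every key might look like the following; the helper names (`original_key`, `find_mismatches`, `check_files`) are hypothetical, and the comparison function is passed in as an argument so that the core logic does not depend on torch. It only checks keys present in the new file, so it would not notice a key that was dropped entirely.

```python
import re

def original_key(new_key):
    """Map a renamed key back to its original name, e.g. edge_importance3 -> edge_importance.3."""
    return re.sub(r"^edge_importance(\d+)$", r"edge_importance.\1", new_key)

def find_mismatches(origin_weights, new_weights, equal):
    """Return the keys of new_weights whose value is absent from, or differs in, origin_weights."""
    mismatched = []
    for new_key, new_value in new_weights.items():
        old_key = original_key(new_key)
        if old_key not in origin_weights or not equal(origin_weights[old_key], new_value):
            mismatched.append(new_key)
    return mismatched

def check_files(origin_path, new_path):
    """Load both weight files and list the mismatched keys (an empty list means all values match)."""
    import torch  # imported lazily so the helpers above do not depend on torch
    origin_weights = torch.load(origin_path)
    new_weights = torch.load(new_path)
    # torch.equal returns a single bool: True if both tensors have the same size and elements
    return find_mismatches(origin_weights, new_weights, torch.equal)
```

Assuming both files from this post exist, `check_files('./joint_model_stgcn.pt', './new_joint_model_stgcn.pt')` would be expected to return an empty list if the renaming left every value unchanged.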