【学习参考】Pytorch 加载、查看预训练模型参数、使用部分预训练模型参数初始化网络
- 取出自己网络的参数字典
- 加载预训练网络的参数字典
- 取出预训练网络的参数字典
- 自己网络和预训练网络结构一致的层,使用预训练网络对应层的参数初始化
# 取出自己网络的参数字典
model_dict = model.state_dict()
# 加载预训练网络的参数字典
pretrained_dict = torch.load("xxxxxx.pth")
# 取出预训练网络的参数字典
keys = []
for k, v in pretrained_dict.items():
keys.append(k)
i = 0
# 自己网络和预训练网络结构一致的层,使用预训练网络对应层的参数初始化
for k, v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
model_dict[k] = pretrained_dict[keys[i]]
#print(model_dict[k])
i = i + 1
model.load_state_dict(model_dict)
2【报错】IndexError:list index out of range
检查list len()发现新创建的模型参数有898;预训练的模型参数有368 。
思路:反过来遍历!
# 取出自己网络的参数字典
model_dict = model.state_dict()
# 加载预训练网络的参数字典
pretrained_dict = torch.load("xxxxxx.pth")
# 取出预训练网络的参数字典
keys = []
for k, v in pretrained_dict.items():
keys.append(k)
i = 0
# 自己网络和预训练网络结构一致的层,使用预训练网络对应层的参数初始化
for k, v in model_dict.items():
if v.size() == model_dict[keys[i]].size():
model_dict[keys[i]]=pretrained_dict[k]
#print(model_dict[k])
i = i + 1
model.load_state_dict(model_dict)
3 其他的调用的方式
if path is not None:
self.load(path)
import torch
class BaseModel(torch.nn.Module):
# path为pt文件路径:
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device("cpu"))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)