pretrained_dict = torch.load(model_path) #从预训练文件中获取预训练权重
#修改键值名称
model_dict = model.state_dict()
new_dict = {k[15:]: v for k, v in pretrained_dict.items()}
#找相同的键值对
backbone_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.items()}
model_dict.update(backbone_dict)
model.load_state_dict(model_dict)
加载预训练模型时,出现键值/前缀不同及网络结构不同的问题
最新推荐文章于 2024-07-14 23:26:10 发布