resnet = models.resnet50(pretrained=True)
new_state_dict = resnet.state_dict()
dd = net.state_dict() #net是自己定义的含有resnet backbone的模型
for k in new_state_dict.keys():
print(k)
if k in dd.keys() and not k.startswith('fc'): #不使用全连接的参数
print('yes')
dd[k] = new_state_dict[k]
net.load_state_dict(dd)
pytorch加载预训练模型部分参数
最新推荐文章于 2022-04-06 11:00:44 发布