def load_weights(self, base_file):
    """Partially load a saved state dict into this module.

    Reads ``base_file`` with ``torch.load`` (tensors mapped onto the CPU) and
    copies every entry whose key also exists in ``self.state_dict()`` and
    whose tensor size matches the current one. Entries with unknown keys or
    mismatched sizes are skipped silently. Parameters the file does not cover
    keep their current values.

    Args:
        self: object exposing ``state_dict()`` / ``load_state_dict()``
            (presumably a ``torch.nn.Module``).
        base_file: path or file-like object accepted by ``torch.load``. Its
            contents must be a mapping (``.items()`` is called on it).
    """
    print('Loading weights into state dict...')
    cpu = torch.device('cpu')
    checkpoint = torch.load(base_file, map_location=cpu)
    merged = self.state_dict()
    for name, tensor in checkpoint.items():
        # Skip keys the model does not have.
        if name not in merged:
            continue
        # Skip tensors whose shape changed (e.g. a resized head layer).
        if merged[name].size() != tensor.size():
            continue
        merged[name] = tensor
    # Every key in `merged` came from self.state_dict(), so strict loading
    # cannot fail on missing or unexpected keys.
    self.load_state_dict(merged)
    print('Finished!')
def load_weight(new_model, model_dir):
    """Copy shape-compatible tensors from a saved state dict into ``new_model``.

    For each entry of ``new_model.state_dict()``, the checkpoint tensor with
    the same key is taken when its size matches. A message is printed for
    every tensor taken. Entries with no matching key or a different size keep
    the model's current values, which allows loading after small
    architecture changes.

    Args:
        new_model: object exposing ``state_dict()`` / ``load_state_dict()``
            (presumably a ``torch.nn.Module``).
        model_dir: path or file-like object accepted by ``torch.load``. Its
            contents must be a mapping keyed like the model's state dict.
    """
    # State dict of the current network.
    model_state = new_model.state_dict()
    # State dict read from disk, with all tensors mapped onto the CPU.
    saved_state = torch.load(model_dir, map_location=torch.device('cpu'))
    # One dict lookup per model key. This replaces the earlier nested scan
    # over every checkpoint key, which was O(n*m) for the same result.
    for key, value in model_state.items():
        if key in saved_state and value.size() == saved_state[key].size():
            print("key对应且形状相同!")
            # Overwriting an existing key does not resize the dict, so it is
            # safe to do during iteration.
            model_state[key] = saved_state[key]
    new_model.load_state_dict(model_state)
    print('load model Finished!')
按 key 读取已有的模型参数并继续训练；在网络结构略作修改时可以使用。
加载模型，简单复现。