Problem Background
GitHub open-source project: https://github.com/zhang-tao-whu/e2ec
python train_net.py coco_finetune --bs 12 \
--type finetune --checkpoint data/model/model_coco.pth
The error output is as follows:
loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
load model: data/model/model_coco.pth
Traceback (most recent call last):
File "train_net.py", line 67, in <module>
main()
File "train_net.py", line 64, in main
train(network, cfg)
File "train_net.py", line 40, in train
begin_epoch = load_network(network, model_dir=args.checkpoint, strict=False)
File "/root/autodl-tmp/e2ec/train/model_utils/utils.py", line 66, in load_network
net.load_state_dict(net_weight, strict=strict)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Network:
size mismatch for dla.ct_hm.2.weight: copying a param with shape torch.Size([80, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
size mismatch for dla.ct_hm.2.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).
My own dataset has only 1 class, while the COCO dataset has 80, so the shape of the dla.ct_hm.2 parameters in the pretrained model does not match my model. The weights for these parameters therefore need to be discarded when loading the pretrained model.
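Before patching the loader, it can help to confirm exactly which keys disagree. The sketch below is a toy illustration: the two state dicts are hand-built stand-ins using the shapes from the error message above; in practice you would compare `torch.load(...)['net']` against `net.state_dict()`.

```python
import torch

# Stand-ins for the checkpoint's and the current model's state dicts,
# with the shapes reported in the RuntimeError above.
ckpt_state = {'dla.ct_hm.2.weight': torch.zeros(80, 256, 1, 1),
              'dla.ct_hm.2.bias': torch.zeros(80)}
model_state = {'dla.ct_hm.2.weight': torch.zeros(1, 256, 1, 1),
               'dla.ct_hm.2.bias': torch.zeros(1)}

# Collect every key present in both dicts whose shapes differ.
mismatched = [k for k in ckpt_state
              if k in model_state and ckpt_state[k].shape != model_state[k].shape]
print(mismatched)  # → ['dla.ct_hm.2.weight', 'dla.ct_hm.2.bias']
```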
Solution
Modify load_network in e2ec/train/model_utils/utils.py as follows:
def load_network(net, model_dir, strict=True, map_location=None):
    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0
    print('load model: {}'.format(model_dir))
    if map_location is None:
        pretrained_model = torch.load(model_dir, map_location={'cuda:0': 'cpu', 'cuda:1': 'cpu',
                                                               'cuda:2': 'cpu', 'cuda:3': 'cpu'})
    else:
        pretrained_model = torch.load(model_dir, map_location=map_location)
    if 'epoch' in pretrained_model.keys():
        epoch = pretrained_model['epoch'] + 1
    else:
        epoch = 0
    pretrained_model = pretrained_model['net']
    net_weight = net.state_dict()
    for key in net_weight.keys():
        net_weight.update({key: pretrained_model[key]})
    # Discard the COCO 80-class head weights, which do not match
    # the 1-class head of the current model.
    net_weight.pop("dla.ct_hm.2.weight")
    net_weight.pop("dla.ct_hm.2.bias")
    net.load_state_dict(net_weight, strict=strict)
    return epoch
Note: setting strict=False in load_state_dict only handles layers that are added or removed; it does not cover the case where an existing parameter changes shape, which is why the mismatched weights must be popped from the state dict before loading.
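Hard-coding the two key names works for this checkpoint, but a more general approach is to drop every pretrained entry whose shape disagrees with the current model. The sketch below assumes only standard PyTorch; the helper name filter_matching and the toy backbone.conv1.weight key are made up for illustration.

```python
import torch

def filter_matching(pretrained_state, model_state):
    """Keep only entries whose name and shape match the current model;
    return the filtered dict and the names that were dropped."""
    kept, dropped = {}, []
    for name, value in pretrained_state.items():
        if name in model_state and value.shape == model_state[name].shape:
            kept[name] = value
        else:
            dropped.append(name)
    return kept, dropped

# Toy example with the shapes from the error message above; the
# backbone.conv1.weight key is hypothetical and stands in for a layer
# whose shape did not change between the two models.
pretrained = {'dla.ct_hm.2.weight': torch.zeros(80, 256, 1, 1),
              'backbone.conv1.weight': torch.zeros(64, 3, 7, 7)}
model = {'dla.ct_hm.2.weight': torch.zeros(1, 256, 1, 1),
         'backbone.conv1.weight': torch.zeros(64, 3, 7, 7)}
kept, dropped = filter_matching(pretrained, model)
print(dropped)  # → ['dla.ct_hm.2.weight']
```

With real state dicts you would then call net.load_state_dict(kept, strict=False), so no key list needs to be maintained by hand when the number of classes changes again.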