base_trainer.py
# Optionally warm-start `net` from a checkpoint file.
# Only entries of checkpoint['net'] whose key exists in the current model AND
# whose shape matches the model's entry are copied in; every other model entry
# keeps its current value. This lets a checkpoint be loaded into a model whose
# architecture only partially matches the one that produced it.
if check_point_net_path is not None:
    # NOTE(review): loading.torch_load_legacy is a project helper, presumably a
    # wrapper around torch.load -- confirm. The result is indexed with 'net'
    # below, so it is expected to be a mapping holding the saved state dict.
    checkpoint = loading.torch_load_legacy(check_point_net_path)
    model_dict = net.state_dict()

    pretrained_dict = {}
    skipped_keys = []
    for k, v in checkpoint['net'].items():
        if k in model_dict and v.shape == model_dict[k].shape:
            pretrained_dict[k] = v
            print(k)  # report each key that is actually loaded
        else:
            # Key unknown to the model, or shape mismatch: keep the model's value.
            skipped_keys.append(k)

    if skipped_keys:
        # Make a partial load visible instead of silently dropping weights.
        print(f"pretrain skipped {len(skipped_keys)} key(s) "
              f"(missing from model or shape mismatch): {skipped_keys}")

    # Overlay the compatible weights on the model's current state and load the
    # merged dict, so keys absent from pretrained_dict keep their current values.
    model_dict.update(pretrained_dict)
    net.load_state_dict(model_dict)
    print("pretrain loading success!")