Source code
if os.path.isfile(model_path):
    print('==> loading checkpoint {}'.format(args.resume))
    checkpoint = torch.load(model_path)
    start_epoch = checkpoint['epoch']
    net.load_state_dict(checkpoint['net'])
else:
    print('==> no checkpoint found at {}'.format(args.resume))
Problem
Traceback (most recent call last):
  File "train_img_model_xent.py", line 251, in <module>
    main()
  File "train_img_model_xent.py", line 136, in main
    checkpoint = torch.load(model_path)
  File "C:\Users\Liang\Anaconda3\lib\site-packages\torch\serialization.py", line 268, in load
    return _load(f, map_location, pickle_module)
  File "C:\Users\Liang\Anaconda3\lib\site-packages\torch\serialization.py", line 421, in _load
    result = unpickler.load()
UnicodeDecodeError: 'ascii' codec can't decode byte 0xe9 in position 0: ordinal not in range(128)
Cause
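This is a Python version incompatibility. The source code and its checkpoint come from an older Python (Python 2), while the current environment runs a newer one (Python 3). Python 2 pickles its str objects as raw bytes. When Python 3's unpickler reads them, it decodes them to str using its encoding argument, which defaults to ASCII. Any byte above 0x7F, such as the 0xe9 in the traceback, therefore raises UnicodeDecodeError.
Python 3 also overhauled string encoding and decoding more generally. For example, the built-in reload() was removed, so the old Python 2 workaround reload(sys); sys.setdefaultencoding(...) no longer applies.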
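The mechanism can be reproduced in isolation. The sketch below hand-builds a tiny protocol-2 pickle of the kind Python 2 writes for a str holding the byte 0xe9 (the byte from the traceback), then unpickles it in Python 3 with different encoding values. It uses only the standard pickle module. The byte string PY2_STR_PICKLE is a hypothetical stand-in, not data taken from the author's checkpoint.

```python
import pickle

# Hand-written protocol-2 pickle of a Python 2 str containing the single byte 0xe9.
# Opcodes: PROTO 2 (\x80\x02), SHORT_BINSTRING of length 1 (U\x01\xe9), STOP (.).
PY2_STR_PICKLE = b'\x80\x02U\x01\xe9.'

try:
    pickle.loads(PY2_STR_PICKLE)  # default encoding='ASCII', like the failing torch.load
except UnicodeDecodeError as exc:
    print(exc)  # 'ascii' codec can't decode byte 0xe9 in position 0: ordinal not in range(128)

# latin-1 (iso-8859-1) accepts every byte value, so decoding succeeds.
print(repr(pickle.loads(PY2_STR_PICKLE, encoding='latin1')))  # 'é'
# encoding='bytes' keeps Python 2 str objects as raw bytes instead.
print(pickle.loads(PY2_STR_PICKLE, encoding='bytes'))  # b'\xe9'
```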
Solution
As the traceback indicates, open .\torch\serialization.py (in the site-packages directory) and locate the failing code:
_sys_info = pickle_module.load(f)
unpickler = pickle_module.Unpickler(f)
unpickler.persistent_load = persistent_load
result = unpickler.load()
Change this code in torch's serialization.py to:
_sys_info = pickle_module.load(f, encoding='iso-8859-1')  # changed
unpickler = pickle_module.Unpickler(f, encoding='iso-8859-1')  # changed
unpickler.persistent_load = persistent_load
result = unpickler.load()
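The patch works because ISO-8859-1 (latin-1) maps each of the 256 byte values to the Unicode code point with the same number. Decoding therefore never fails, and the original bytes can be recovered exactly. This matters for numerical data: NumPy's np.load permits only 'latin1', 'ASCII' or 'bytes' for Python 2 pickles, because other encodings can corrupt numbers.

The sketch below mimics the patched sequence on an in-memory stream of two hand-written, Python 2 style pickles: pickle.load reads the header, then an Unpickler reads the payload. The function load_py2_stream and the byte strings are hypothetical stand-ins, not torch code.

In later PyTorch releases, torch.load forwards extra keyword arguments to the unpickler, so torch.load(model_path, encoding='latin1') should have the same effect without editing the library. The traceback above, where _load takes only map_location and pickle_module, suggests the author's version predates that.

```python
import io
import pickle

# Hypothetical Python 2 style stream: two protocol-2 pickles back to back,
# mirroring the "_sys_info" header followed by the main payload.
HEADER = b'\x80\x02U\x03abc.'        # Python 2 str 'abc'
PAYLOAD = b'\x80\x02U\x02\xe9\xff.'  # Python 2 str holding bytes 0xe9 0xff


def load_py2_stream(f, encoding='iso-8859-1'):
    """Read header and payload the same way the patched serialization.py does."""
    sys_info = pickle.load(f, encoding=encoding)
    unpickler = pickle.Unpickler(f, encoding=encoding)
    payload = unpickler.load()
    return sys_info, payload


sys_info, payload = load_py2_stream(io.BytesIO(HEADER + PAYLOAD))
print(sys_info)                      # abc
# Re-encoding with latin-1 gives back exactly the bytes that were pickled.
print(payload.encode('iso-8859-1'))  # b'\xe9\xff'
```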
Then change the source code to:
if os.path.isfile(model_path):
    print('==> loading checkpoint {}'.format(args.resume))
    checkpoint = torch.load(model_path)
    start_epoch = checkpoint['epoch']
    net.load_state_dict(checkpoint['net'], strict=False)  # changed
else:
    print('==> no checkpoint found at {}'.format(args.resume))
Notes
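If the source code is left unchanged, load_state_dict() raises the error below because the key names in the checkpoint do not match the model's. The unexpected num_batches_tracked entries are BatchNorm buffers introduced in PyTorch 0.4.1. Their presence suggests the checkpoint was saved with a newer PyTorch than the one loading it. Keep in mind that strict=False ignores both missing and unexpected keys, so any parameters listed as missing keep their initialized values instead of values from the checkpoint.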
Traceback (most recent call last):
  File "E:/wj-lab/expStad/test.py", line 79, in <module>
    net.load_state_dict(checkpoint['net'])
  File "D:\Anaconda3\envs\python35\lib\site-packages\torch\nn\modules\module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for embed_net:
	Missing key(s) in state_dict: * * *
	Unexpected key(s) in state_dict: "visible_net.visible.bn1.num_batches_tracked", "visible_net.visible.layer1.0.bn1.n
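A narrower alternative to strict=False is to see exactly which keys differ and drop only the ones known to be harmless. The sketch below is an assumption, not part of PyTorch: diff_state_dict_keys and drop_unknown_suffix are hypothetical helpers, and plain dicts stand in for state_dict() contents.

diff_state_dict_keys computes the same missing/unexpected key lists that strict loading reports. It compares keys only; tensor shapes are not checked. drop_unknown_suffix removes num_batches_tracked entries that the model does not have.

```python
from typing import Dict, Iterable, List, Mapping, Tuple


def diff_state_dict_keys(model_keys: Iterable[str],
                         ckpt: Mapping[str, object]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, compared by name only."""
    model_set, ckpt_set = set(model_keys), set(ckpt)
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent from the file
    unexpected = sorted(ckpt_set - model_set)  # present in the file, unknown to the model
    return missing, unexpected


def drop_unknown_suffix(ckpt: Mapping[str, object], model_keys: Iterable[str],
                        suffix: str = 'num_batches_tracked') -> Dict[str, object]:
    """Copy ckpt without entries that end in `suffix` and that the model lacks."""
    model_set = set(model_keys)
    return {k: v for k, v in ckpt.items()
            if not (k.endswith(suffix) and k not in model_set)}


# Toy example: plain dicts standing in for net.state_dict() and checkpoint['net'].
model_keys = ['bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var']
ckpt = {k: 0.0 for k in model_keys}
ckpt['bn1.num_batches_tracked'] = 0

print(diff_state_dict_keys(model_keys, ckpt))     # ([], ['bn1.num_batches_tracked'])
cleaned = drop_unknown_suffix(ckpt, model_keys)
print(diff_state_dict_keys(model_keys, cleaned))  # ([], [])
```

With a real model, one would pass net.state_dict().keys() and checkpoint['net'], then call net.load_state_dict(cleaned) with the default strict=True. That way, any other mismatch still raises an error instead of being silently skipped.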