错误原因
- 使用了 torch.load() 加载模型
解决办法
-
使用 model.load_state_dict(torch.load(./model.pt,map_location=‘cpu’),strict=False)
-
map_location:当模型训练的时候用的 gpu 但是加载的时候用的是 cpu 环境,这个时候要进行映射
-
strict=False 否则容易报错:Unexpected key(s) in state_dict: “lstm.weight_ih_l3”, “lstm.weight_hh_l3”…
总结
如无必要,尽量用 load_state_dict 的方式来加载模型,这种方式更稳定
用 torch.load() 的话,很容易报各种各样的错误