问题表现
运行torch.load("xxx.pth")
加载模型参数,报错:
RuntimeError: xxx.pth is a zip archive(did you mean to use torch.jit.load()?)
原因
xxx.pth
来自pytorch1.6或更高的版本。1.6之后pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。
解决方案
在pytorch1.6版本下运行:
state_dict = torch.load("xxx.pth")
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)
则将权重文件转换成非zip格式。