这是因为pytorch1.6以后默认把模型存储为压缩文件,而老版本不是压缩文件,因此识别不了
解决方法:可以在新版本中将模型加载后,再存储为非压缩文件,再用老版本加载:
def save(self, apath, epoch, is_best=False):
target = self.get_model()
torch.save(
target.state_dict(),
os.path.join(apath, 'model', 'model_latest.pt'),
_use_new_zipfile_serialization=False
)
if is_best:
torch.save(
target.state_dict(),
os.path.join(apath, 'model', 'model_best.pt'),
_use_new_zipfile_serialization=False
)