在使用pytorch加载模型时报错:
torch.save(old_model, PATH)
new_model = torch.load(PATH)
AttributeError: Can't get attribute 'Net' on <module '__main__'>
解决办法:
1、将类的定义添加到加载模型的这个py文件中,这个方法有点。。。
2、使用官方推荐的方法:https://pytorch.org/docs/master/notes/serialization.html
只保存,加载模型的权重参数:
torch.save(the_model.state_dict(), PATH)
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))