(1)原文:
(2)问题展示(这是错误代码!):
import torch
model = torch.load('model_method1.pth')
#报错信息:AttributeError: Can't get attribute 'Net' on <module '__main__' from 'D:\\在D盘的pytorch项目\\PycharmProjects\\pytorch\\Minist\\test.py'>
(3)解决方法:
1.找到使用torch.save的那个python文件(我的是“MNIST1_train.py”)
2.把这个文件里面的网络类(我的是Net类)导入到需要torch.load的文件(我的是test.py)
3.在test.py中加入代码
有问题请联系:chufeng0105@qq.com
import torch
from MNIST1_train import Net
model = torch.load('model_method1.pth')
(4)成功展示