一、保存
# 保存模型到路径
torch.save(Batch_Net(28*28, 300, 100, 10), r'C:\Users\11868\Desktop\net.pth')
# 保存模型的参数
torch.save(model.state_dict(), r'C:\Users\11868\Desktop\state_dict.pth')
注意:若模型初始化需要指定参数,则保存时要添加参数。
二、载入
# 加载模型
model = torch.load(r'C:\Users\11868\Desktop\net.pth')
# 加载参数
model.load_state_dict(torch.load(r'C:\Users\11868\Desktop\state_dict.pth'))
model.eval() # 将模型改为测试模式
注意:必须调用model.eval(),以便在运行推断之前将dropout和batch规范化层设置为评估模式。如果不这样做,将会产生不一致的推断结果。