一:只保存和加载模型参数
1 . 保存模型参数:
import torch
torch.save(model.state_dict(), 'save_path_name.pth')
2 . 加载模型参数:
import torch
import torch.nn as nn
model.load_state_dict(torch.load('save_path_name.pth'), strict=True)
方式二:保存和加载整个模型(模型结构和模型参数)
1 . 保存模型:
import torch
torch.save(model, 'save_path_name.pth')
2 . 加载模型:
import torch
import torch.nn as nn
model = torch.load('save_path_name.pth')
# 保存模型到路径
torch.save(Batch_Net(28*28, 300, 100, 10), r'C:\Users\11868\Desktop\net.pth')
# 保存模型的参数
torch.save(model.state_dict(), r'C:\Users\11868\Desktop\state_dict.pth')
————————————————
版权声明:本文为CSDN博主「Answerlzd」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/Answer3664/article/details/98084300
# 加载模型
model = torch.load(r'C:\Users\11868\Desktop\net.pth')
# 加载参数
model.load_state_dict(torch.load(r'C:\Users\11868\Desktop\state_dict.pth'))
model.eval() # 将模型改为测试模式