模型保存/加载的四种方法
1.保存/加载状态字典(state_dict)
2.保存/加载整个模型(entire model)
3.保存/加载checkpoint信息
4.保存/加载多个模型到一个文件
注:详情请参阅 pytorch 官方文档 链接
"""模型保存与加载
方法1:保存/加载状态字典(state_dict)
该方法具有更大的灵活性,推荐使用
方法2:保存/加载整个模型(entire model)
方法3:保存/加载checkpoint
该方法以字典形式存储模型信息,推荐训练过程使用
方法4:保存/加载多个模型到一个文件
该方法可用于模型重用(即使用预训练模型)
主要针对像GAN/sequence-to-sequence model/an ensemble of models(一组模型) 这样包含多个torch.nn.Modules的模型
"""
import torch
import torch.nn as nn # 模型构建模块
import torch.optim as optim # 优化器模块
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(