torch模型保存。
模型保存的本质就是利用pickle模块进行序列化。序列化到文件,从文件反序列化回来的对象,要么是Python自定义的对象,要么是本文件中已经定义的类。
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
def __init__(self, input_size, output_size):
super(Model, self).__init__()
self.linear1 = nn.Linear(input_size, input_size * 2)
self.linear2 = nn.Linear(input_size * 2, output_size)
def forward(self, inputs):
inputs = self.linear1(inputs)
output = self.linear2(inputs)
return output
第一种方式
model = Model()
torch.save(model,'./model.pth')
model = torch.load('./model.pth')
第二种方式
model = Model()
torch.save(model.state_dict(), './model_state_dict.pth')
model = Model()
model.load_state_dict('./model_state_dict.pth')