【Pytorch-6】-模型保存与加载

其实Pytorch模型保存还是挺简单的,但是不同方式也有优劣之分吧。有时候,我们不仅仅需要保存模型参数,而有时需要保存训练的所有现场,包括优化器的内容。即有时候是只保存参数,但有时候需要保存模型训练的全过程。

1. 保存state_dict(参数和主要信息)

我们实际上保存的是模型的参数,没有保存模型的结构的完整信息。

即,保存的模型是以字典形式保存的,所以被称作为state_dict。上面实际上我们按照已经定义好的模型进行加载,所以使用model.load_state_dict。其中的键信息实际是原本模型的层次的名字,因此模型在重新读取的时候,需要我们先实例化完全一致的结构,再进行参数的加载。

如果model是pytorch的nn.module继承而来的,那么如下:

model_path = os.path.join(output, 'model.pth')
torch.save(model.state_dict(), model_path)

这里有.pth的格式存储,还有.pkl格式,以及.pt的格式。

之后,如果要进行推理或者使用时加载模型,只需要模型的结构对应,就可以直接加载:

model.load_state_dict(torch.load(args.model_path))
# args.model_path就是模型的路径字符串,比如'model.pth'

总结如下:

  • 保存模型时调用 state_dict() 获取模型的参数,而不保存结构
  • 加载模型时需要预先实例化一个对应的结构
  • 加载模型使用 load_state_dict 方法,其参数不是文件路径,而是 torch.load(PATH)

2. 存取整个模型

这是完整的存储了模型的信息的方法,包括模型的参数信息、模型的结构信息、参数等等所有内容。和方法一相比,弊端是会占用更大的信息,优势是,我们不需要知道文件中的模型究竟是什么样的,直接读取即可使用了:

torch.save(model, PATH)

model = torch.load(PATH)

3.存取checkpoint

有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint,往往除了模型,还要保存它的epochlossoptimizer等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。

这里我们主要将多个内容放入一个字典进行保存:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载的时候,我们需要将各个对应的元素按照原本的类别,进行数据初始化,例如优化器必须还是之前的优化器,模型还是之前的模型结构(主要这里例子是state_dict,不然直接保存模型也是可以的)

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

4. 设备不同时的转换问题

我们时常会涉及到,在有GPU的服务器进行训练,但是在CPU上进行推理和使用的情况。正常的CPU训练、CPU加载或者GPU训练、GPU使用,都是没问题的,主要是设备不同时的问题。

GPU训,GPU加载

最为正常和一般的情况,照常操作,不过还是别忘记把模型放到GPU上去。

GPUidx=0
device = torch.device('cuda:{}'.format(GPUidx) if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64     # number of data points in each batch
N_EPOCHS = 15       # times to run the model on complete data
INPUT_DIM = 28 * 28 # size of each input
HIDDEN_DIM = 256    # hidden dimension
LATENT_DIM = 20     # latent vector dimension

encoder = Encoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM) # encoder
decoder = Decoder(LATENT_DIM, HIDDEN_DIM, INPUT_DIM) # decoder
VAEmodel = VAE(encoder, decoder).to(device)# vae

VAEmodel.load_state_dict(torch.load(modelpath))

GPU训练,CPU加载

保存的行为一致,我们只需要在torch.load时,对相应的参数map_location进行设置即可:

torch.save(net.state_dict(), PATH)

device = torch.device("cpu")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

CPU训练,GPU加载

虽然一般不太可能,但还是啰嗦一下

torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
# or
loaded_net.to(device)
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值