0. 前言
(1)三个重要的函数:
torch.save: 将序列化的对象存储到硬盘中.
torch.load:该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也支持设备加载数据.
torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典
(2)什么是state_dict?
在 PyTorch 中,state_dict
是一个 Python 字典对象,它将每个层的参数(权重和偏差)映射到对应的张量。state_dict
通常用于保存和加载模型的权重参数,并且可以非常方便地将模型的状态保存到磁盘上或者在不同的 PyTorch 程序之间共享模型的权重参数。
state_dict
的键是层的名称,值是层的权重和偏差。
请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm的running_mean)才在模型的state_dict中具有条目。优化器对象(torch.optim)还具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
1. 模型保存
1. 权重参数保存的三种方式
(1)第一种: 将网络模型和对应的参数保存在一起;( pickle 包)
# 储存模型
torch.save( net123,"./weights/All_in.pth")
# 加载模型
net123 = torch.load("./weights/All_in.pth")
# 额外需要的操作
model.eval()
(2)第二种: 模型和参数分离, 单独的保存模型的权重参数;(state_dict方式)【推荐】
推荐, 便于网络模型修改后, 提取出对应层的参数权重;
net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;
# 保存模型及参数
torch.save(net123.state_dict(),'./weights/epoch_weight.pth')
# 加载模型
net123 = module.CustomModel(*args, **kwargs)
net123.load_state_dict(torch.load('epoch_weight.pth'))
model.eval()
注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先torch.load()该模型, 然后再使用 load_state_dict().
(3)第三种: 除了权重参数, 用于模型训练的超参数也保存其中(checkpoint方式)。
# 保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
# 加载
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
存储 checkpoints
主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便.
2.跨平台参数保存与加载
save on CPU , load on GPU
# 保存
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
# 加载
# build
encoder = TransformerModel(params, dico, is_encoder=True, with_output=False) # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
# reload pretrained word embeddings
if params.reload_emb != '':
# 表示加载预训练模型
word2id, embeddings = load_embeddings(params.reload_emb, params)
set_pretrain_emb(encoder, dico, word2id, embeddings)
set_pretrain_emb(decoder, dico, word2id, embeddings)
set_pretrain_emb(model2, dico, word2id, embeddings)
# reload a pretrained model
if params.reload_model != '':
enc_path, dec_path = params.reload_model.split(',')
assert not (enc_path == '' and dec_path == '')
# reload encoder
if enc_path != '':
enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
# 预训练模型是在 GPU 上训练的
enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
# 导入存储的文件的模型
if all([k.startswith('module.') for k in enc_reload.keys()]):
enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
# 这个过程相当于将model 反序列化为 state_dict的形式
encoder.load_state_dict(enc_reload, strict=False)
# 这个后面的 strict=False 就是对 encoder 与 enc_reload.state_dict之间差异进行处理, 如果encoder 的模型结构与 enc_reload模型结构
# 不一样的时候, 就会向 encoder 转化, 也就是 encoder 不包含的层就不会导入, 例如这里 enc_reload 就是一个完整的 Transformer 模型, 但是
# encoder 是不包含输出部分的, 所以就不会加载这部分
reference:
https://blog.csdn.net/kingonlyuserjava/article/details/106429755