深度学习模型的保存与加载

0. 前言

(1)三个重要的函数:

torch.save: 将序列化的对象存储到硬盘中.

torch.load:该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也支持设备加载数据.

torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典

(2)什么是state_dict?

在 PyTorch 中,state_dict 是一个 Python 字典对象,它将每个层的参数(权重和偏差)映射到对应的张量。state_dict 通常用于保存和加载模型的权重参数,并且可以非常方便地将模型的状态保存到磁盘上或者在不同的 PyTorch 程序之间共享模型的权重参数。

state_dict 的键是层的名称,值是层的权重和偏差。

请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm的running_mean)才在模型的state_dict中具有条目。优化器对象(torch.optim)还具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

1. 模型保存

1. 权重参数保存的三种方式

(1)第一种: 将网络模型和对应的参数保存在一起;( pickle 包)
# 储存模型
torch.save( net123,"./weights/All_in.pth")

# 加载模型
net123 = torch.load("./weights/All_in.pth")
# 额外需要的操作
model.eval()
(2)第二种: 模型和参数分离, 单独的保存模型的权重参数;(state_dict方式)【推荐

推荐, 便于网络模型修改后, 提取出对应层的参数权重;

net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;

# 保存模型及参数
torch.save(net123.state_dict(),'./weights/epoch_weight.pth')

# 加载模型
net123 = module.CustomModel(*args, **kwargs)
net123.load_state_dict(torch.load('epoch_weight.pth'))
model.eval()

注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先torch.load()该模型, 然后再使用 load_state_dict().
 

(3)第三种: 除了权重参数, 用于模型训练的超参数也保存其中(checkpoint方式)。
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# 加载

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便.

 2.跨平台参数保存与加载

save on CPU , load on GPU

# 保存
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

# 加载
# build
        encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)  # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
        decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
        
# reload pretrained word embeddings
        if params.reload_emb != '':
        	# 表示加载预训练模型
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(encoder, dico, word2id, embeddings)
            set_pretrain_emb(decoder, dico, word2id, embeddings)
            set_pretrain_emb(model2, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            enc_path, dec_path = params.reload_model.split(',')
            assert not (enc_path == '' and dec_path == '')

            # reload encoder
            if enc_path != '':
                enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                # 预训练模型是在 GPU 上训练的
                enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                # 导入存储的文件的模型
                if all([k.startswith('module.') for k in enc_reload.keys()]):
                    enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
                # 这个过程相当于将model 反序列化为 state_dict的形式
                encoder.load_state_dict(enc_reload, strict=False)
                # 这个后面的 strict=False 就是对 encoder 与 enc_reload.state_dict之间差异进行处理, 如果encoder 的模型结构与 enc_reload模型结构
                # 不一样的时候, 就会向 encoder 转化, 也就是 encoder 不包含的层就不会导入, 例如这里 enc_reload 就是一个完整的 Transformer 模型, 但是
                # encoder 是不包含输出部分的, 所以就不会加载这部分

reference:

https://blog.csdn.net/kingonlyuserjava/article/details/106429755

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值