深度学习模型的集成

目前,了解到的深度学习模型集成的方法主要有两种:

第一:平均checkpoints

即对存储好的多个checkpoint中的参数求平均,然后重新保存。这里的多个checkpoint可以是同一份数据训练,模型收敛后存储的多个模型;也可以是不同的数据训练得到的模型。

该方法的优点是可以提高单个模型的效果,并且推理的速度和存储都不会发生变化。缺点在于待平均的模型的结构要完全一致,另外提升的效果不会很大。以pytorch为例,代码如下:

import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

dir_path = "/save_checkpoint_path/"
# 存储模型的路径
start_id = 40  # 待求平均的模型开始索引
end_id = 50  # 待求平均的模型结束索引

models = []
for i in range(start_id + 1, end_id):
    model_path = dir_path + "checkpoint_" + str(i) + ".pt"
    models.append(model_path)
#
checkpoint_path = os.path.join(dir_path, "checkpoint_" + str(start_id) + ".pt")
state = torch.load(checkpoint_path)

count = 0
for cpt in models:
    count += 1
    tmp_state = torch.load(cpt)
    for k in tmp_state:
        state[k] += tmp_state[k]
for k in state:
    state[k] = state[k] / (count + 1)

new_checkpoint_path = dir_path + "/checkpoint_average_point.pt"
torch.save(state, new_checkpoint_path)
# 存储平均之后的checkpoint
print(state)

第二:平均概率

基于深度学习的任务基本都是要在模型的最后一层求得概率或者未归一化的概率。 当有多个模型,并且模型最后预测的类别一致,可以对多个模型预测的概率求平均。这里的模型可以是同一份数据训练的多个模型;也可以是不同的数据训练得到的模型。

这种方法一般效果提升比较大,缺点在于速度和内存都会增加,模型越多需要的内存越大。以pytorch为例,代码如下:

import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

dir_path = "/save_checkpoint_path/"
# 存储模型的路径
start_id = 40  # 待求平均的模型开始索引
end_id = 50  # 待求平均的模型结束索引

models = []
for i in range(start_id, end_id):
    model_path = dir_path + "checkpoint_" + str(i) + ".pt"
    models.append(model_path)

input_x = [1, 2, 3]
# 假设input_x 是模型的输出
avg_pobs = None
for model in models:
    net = torch.load(model)
    prob = net(input_x) 
    avg_pobs.add_(prob)
avg_pobs.div_(len(models))
  • 3
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旺旺棒棒冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值