详细链接:https://erogol.com/online-hard-example-mining-pytorch/
OHEM通过减少计算成本来选择难例,提高网络性能。它主要用于目标检测。假设你想训练一个汽车检测器,并且你有正样本图像(图像中有汽车)和负样本图像(图像中没有汽车)。现在你想训练你的网络。实际上,你会发现负样本数量会远远大于正样本数量,而且大量的负样本是比较简单的。因此,比较明智的做法是选择一部分对网络最有帮助的负样本(比较有难度的,容易被识别为正样本的)参与训练。难例挖掘就是用于选择对网络最有帮助的负样本的。
通常来说,通过对网络训练进行一定的迭代后得到临时模型,使用临时模型对所有的负样本进行测试,便可以发现那些loss很大的负样本实例,这些实例就是所谓的难例。但是这种查找难例的方法,需要很大的计算量,因为负样本图像可能会很多;另外这一方法可能是次优的,当你进行难例挖掘的时候,模型的权重是固定的,当前权重下的难例未必适用于接下了的迭代(这句话不太理解)。也就是说,这里假设你选择的所有难例负样本对下一次迭代都是有用的,直到下一次难例选择。这是一个不完美的假设,尤其是对于大型数据集而言。
OHEM通过批量难例选择选择来解决上述两个问题。给定batch-size K,前向传播保持不变,计算损失。然后,选择M(M<K)个高损失值的实例,仅使用这M个实例的损失进行反向传播。
OHEM的具体pytorch实现代码如下:
import torch as th
class NLL_OHEM(th.nn.NLLLoss):
""" Online hard example mining.
Needs input from nn.LogSotmax() """
def __init__(self, ratio):
super(NLL_OHEM, self).__init__(None, True)
self.ratio = ratio
def forward(self, x, y, ratio=None):
if ratio is not None:
self.ratio = ratio
num_inst = x.size(0)
num_hns = int(self.ratio * num_inst)
x_ = x.clone()
inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()
for idx, label in enumerate(y.data):
inst_losses[idx] = -x_.data[idx, label]
#loss_incs = -x_.sum(1)
_, idxs = inst_losses.topk(num_hns)
x_hn = x.index_select(0, idxs)
y_hn = y.index_select(0, idxs)
return th.nn.functional.nll_loss(x_hn, y_hn)