Balanced MSE 使用指南

Balanced MSE 使用指南


前言

论文地址:https://arxiv.org/abs/2203.16427
代码地址:https://github.com/jiawei-ren/BalancedMSE

本文旨在帮助大家更快将Balanced MSE部署到自己的模型中,解决数据中存在的不平衡问题

Balanced MSE的GAI解法,需要先对数据的标签进行拟合,然后得到一次静态的GMM分布,比较难以适配流式训练下情况,同时GAI的做法依旧有MSE项,依旧没有解决MSE训练不稳定的问题;BMC的做法,从最终的损失函数上来看,是一个比较理想的情况,一个是自适应数据分布,另外一个就是损失函数是softmax形式,解决了MSE的问题

基于此,本文主要介绍如何在我们自己的模型中加入Balanced MSE的BMC实现方式


1、BMC方法介绍

在Batch-based Monte-Carlo (BMC)中,不需要建模训练集上的标签分布,将所有的样本标签看成是训练集标签的随机样本
在这里插入图片描述
将其重新写入Balanced MSE损失函数中,得到
在这里插入图片描述

2、代码实现

将Balanced MSE的BMC用代码实现,如下所示:

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `targets`.
    Args:
      pred: A float tensor of size [batch, 1].
      target: A float tensor of size [batch, 1].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)   # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))     # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale, 'detach' when noise is learnable 

    return loss

在使用损失函数时,我们需要将模型的输出size修改为[batch, 1],否则可能出现如下的报错:

RuntimeError: Expected floating point type for target with class probabilities, got Long

针对上面的报错,主要来自于数据的类型。

这里值得注意的是,变量noise_var是一维超参,是可以学习的

但是在我们的实践中发现,将noise_var设置为可学习的超参数后,效果相较于未将noise_var设置为超参数效果更差一些

关于这一点后续还需要更多对比试验进行验证……

3、模型嵌入

那么,我们如何将Balanced MSE损失函数应用到我们自己的模型中呢?

3.1、CPU方式

这里给出一种即插即用的方法,也是作者在GitHub中给出的方式,代码如下所示:

# 定义Balanced MSE Loss(BMC版本)
105 def bmc_loss(pred, target, noise_var):
106     pred = pred.view(-1, 1)
107     target = target.view(-1, 1)                                                                           
108     logits = - 0.5 * (pred - target.T).pow(2) / noise_var
112     loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
113     loss = loss * (2 * noise_var)
114     return loss
115 
116 
117 class BMCLoss(_Loss):
118     def __init__(self, init_noise_sigma):
119         super(BMCLoss, self).__init__()
121         self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))
122 
123     def forward(self, pred, target):
124         noise_var = self.noise_sigma ** 2
125         return bmc_loss(pred, target, noise_var)
126 
127 init_noise_sigma = 8.0
128 sigma_lr = 1e-2
129 model = Model()
130 criterion = BMCLoss(init_noise_sigma)
131 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
132 optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})

其中两个超参也是我们按照默认值给出,当然大家也可以自己在实验的过程中进行调参,但是我们不建议大家将noise_var打开

3.2、GPU

此外为了方便大家在GPU上使用Balanced MSE,我们也给出了GPU的实现方式,与CPU一样,但我们需要将模型数据加载到GPU上,实现代码如下所示:

# 定义Balanced MSE Loss(BMC版本)
105 def bmc_loss(pred, target, noise_var):
106     pred = pred.view(-1, 1)
107     target = target.view(-1, 1)
108     logits = - 0.5 * (pred - target.T).pow(2) / noise_var
110     loss = F.cross_entropy(logits, torch.arange(pred.shape[0]).cuda())
111     loss = loss * (2 * noise_var)
114     return loss
115 
116 
117 class BMCLoss(_Loss):
118     def __init__(self, init_noise_sigma):
119         super(BMCLoss, self).__init__()
120         self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda"))
122 
123     def forward(self, pred, target):
124         noise_var = self.noise_sigma ** 2
125         return bmc_loss(pred, target, noise_var)
126 
127 init_noise_sigma = 8.0
128 sigma_lr = 1e-2
129 model = Model()
130 criterion = BMCLoss(init_noise_sigma)
131 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
132 optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})
133 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134 model.to(device)

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值