ECCV2020-RBF-Softmax | Learning Deep Representative Prototypes with Radial Basis Function Softmax

在这里插入图片描述

Abstract:

深度神经网络在学习用于视觉分类的特征表示方面取得了显著成功。但是,通过softmax交叉熵损失x学习到的深层特征通常显示出类内的多样性。我们认为,由于传统的softmax损失旨在仅优化类内距离和类间距离(logits)之间的相对差异,因此,即使在以下情况下,也无法获得具有代表性的类原型(类权重/中心)来规范类内距离,即使训练是收敛的。先前的努力通过引入辅助正则化损失来缓解此问题。但是这些修改后的损失主要集中在优化类内的紧凑性,而忽略了保持不同类原型之间的合理关系。这些导致模型较弱,并最终限制了它们的性能。为了解决这个问题,本文介绍了一种新颖的Radial Basis Function (RBF)距离来替代softmax损失函数中常用的内积,从而可以通过重塑来自适应地分配损失以调整类内和类间距离相对差异,从而创建更具代表性的类原型以改进优化。提出的RBF-Softmax损失函数不仅有效地减少了类内距离,稳定了训练行为,并保留了原型之间的理想关系,而且还大大提高了测试性能。在包括MNIST,CIFAR-10 / 100和ImageNet在内的视觉识别基准上进行的实验表明,与交叉熵和其他最新分类损失相比,所提出的RBF-Softmax可获得更好的结果。

创新:

在这里插入图片描述

如上图所示。本文提出Radial Basis Function (RBF)距离来对类内和类间距离进行正则化,有效调整类内和类间距离的合理化,最终目的是学习到了更好的类中心。

RBF-Softmax loss:

本文提出了Radial Basis Function kernel distance(RBF-score),在xi和Wj之间进行测量,以测量样本特征xi与不同类别的权重Wj之间的相似度。
在这里插入图片描述
其中di,j是xi与Wj之间的欧几里得距离,而γ是超参数。 与无界的欧几里得距离和内积相比,RBF分数随着欧几里得距离的增加而减小,其值的范围从0(当di,j→∞时)到1(当xi = Wj时)。 直观地,RBF得分很好地测量了xi和Wj之间的相似度,并且可以用作softmax交叉熵损失函数中的对数。所以本文定义RBF-Softmax loss如下:
在这里插入图片描述

网络训练开始的时候,对于类内的logits比较大的情况,本文通过RBF 距离可以将非常大的欧式距离映射成相对小的RBF-score,造成了在训练的开始阶段,类内的偏差就会显著的变小,有效降低类内的多样性。此外,在训练的后期,传统的softmax loss给出的概率很容易就可以到1。RBF LOSS的情况下的概率很难到1,这样可以持续的进行优化。

代码:

class RBFLogits(nn.Module):
    def __init__(self, feature_dim, class_num, scale, gamma):
        super(RBFLogits, self).__init__()
        self.feature_dim = feature_dim
        self.class_num = class_num
        self.weight = nn.Parameter( torch.FloatTensor(class_num, feature_dim))
        self.bias = nn.Parameter(torch.FloatTensor(class_num))
        self.scale = scale
        self.gamma = gamma
        nn.init.xavier_uniform_(self.weight)
        
    def forward(self, feat, label):
        diff = torch.unsqueeze(self.weight, dim=0) - torch.unsqueeze(feat, dim=1)
        diff = torch.mul(diff, diff)
        metric = torch.sum(diff, dim=-1)
        kernal_metric = torch.exp(-1.0 * metric / self.gamma)
        if self.training:
            train_logits = self.scale * kernal_metric
            # ###
            # Add some codes to modify logits, e.g. margin, temperature and etc.
            # ###
            return train_logits
        else:
            test_logits = self.scale * kernal_metric
            return test_logits

Experiments:

1.Minist与不同损失对比:
在这里插入图片描述

2.ImageNet:在这里插入图片描述

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值