AMSoftmax学习

AMSoftmax

我们来讲一下多类别分类问题。

首先我们好好介绍一下softmax loss function。(其实这个名字可以叫CrossEntropy交叉熵函数)。比如说模型输出数据的维度是1*n,但是在此次训练中数据集的类别是m,为了模型输出的维度和输出数据的维度相匹配,我们利用了一个W矩阵,将输出的feat转为(1*m)维度。它的公式也就是由如下所示。

   (1)

    

之后再用softmax激活函数将prey的数据通过如下的式子推导出来。

  (2)

 

注意得到的公式是只是prey中的某一个维度的值罢了。

而实际的y是由独热编码构成的。所以y其实是其中对应正确类别为1。

接下来我们来介绍AMsoftmax。

之前的softmax loss得到的只是一个vector(基准)但是类间的距离却差别不是很大,而由下图中的左图我们可以看出类内的点variance比较大,分散比较零散。

而右图的Additive Margin Softmax就缩小了类内间距,扩大了类间间距

我在这里理解最后全连接层训练的参数W权重矩阵其实就是每一类的中心向量的汇总,然后将输入的数据x映射到待分类的线性空间上(以中心向量作为基准向量)。但是如此映射Softmax loss如上文提到的分类得到的只是一个向量,为了加大不同类之间的距离,就引用了

这个公式,首先我们来看cos角是什么意思,首先我们回溯到之前的全连接层。

全连接层的 其中假设||W||、||x||为1,而||b||=0,那么得到的Wx相乘结果其实就是余弦角cos值。接下来我们明白了cos值的缘由,我们就聚焦于ϕθ的效果,首先我之前说过,我假设分别为两种类别用于分类的中心向量。而原先训练得到的 但是现在训练得到的是'

cosθ'=cosθ+m (这样才能抵消原先所作的减法)也就是说cosθ'的值需要比原来更大,所以容忍的角度自然也就更小了。

好了既然我们已经知道m的作用了,接下来我们来看看完整的AMSoftmax中preyi概率的计算。

我们可以看到所有都乘以了一个s系数。我们假设一个类间原本的容忍角度是

但是如今乘了一个s系数,导致preyi的概率变化会比以前剧烈,也就是说类间的样本与基准向量角度有略微的偏移会对概率预测的影响比之前没乘系数大,所以按理来说,一个类间容忍角度会变小。

以上只是一名普通大学生自己的翻阅资料学习,得到的感悟如有错误希望纠正。

是的,Pytorch中可以直接调用amsoftmax。 在PyTorch中,可以使用nn.CrossEntropyLoss()函数来计算softmax输出与实际标签之间的交叉熵损失。但是,如果你想使用amsoftmax,需要自定义损失函数。 以下是一个简单的amsoftmax实现示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class AMSoftmax(nn.Module): def __init__(self, in_feats, n_classes, m=0.35, s=30.0): super(AMSoftmax, self).__init__() self.m = m self.s = s self.in_feats = in_feats self.n_classes = n_classes self.weight = nn.Parameter(torch.Tensor(in_feats, n_classes)) nn.init.xavier_uniform_(self.weight) def forward(self, x, label): x_norm = F.normalize(x, p=2, dim=1) w_norm = F.normalize(self.weight, p=2, dim=0) logits = x_norm.mm(w_norm) target_logits = logits[torch.arange(0, x.size(0)), label].view(-1, 1) m_hot = torch.zeros_like(logits).scatter_(1, label.view(-1, 1), self.m) logits_m = logits - m_hot logits_scaled = logits_m * self.s loss = nn.CrossEntropyLoss()(logits_scaled, label) return loss ``` 在此实现中,我们通过继承nn.Module来创建一个自定义的AMSoftmax层。在前向传播中,我们首先使用F.normalize()函数对输入特征x和权重矩阵w进行L2归一化。然后,我们将二者相乘,得到logits。接着,我们从logits中提取出与目标标签对应的logit,并将其视为target_logits。接下来,我们创建一个大小与logits相同的张量m_hot,其中每个样本的目标类别位置用值为m的标量替换。最后,我们从logits_m中减去m_hot,然后将差乘以s,以得到缩放后的logits。最后,我们使用自定义的交叉熵损失函数计算损失并返回。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值