原理
Arc-SoftmaxLoss = Arc-Softmax + NLLLoss。
softmax 是通过角度分类的,Arc-Softmax 加宽了角度间的分界线,从而达到加大类间距的目的。
softmax
V = wx = cosθ·||w||·||x||( cosθ = 二范数归一化后的欧氏距离 = wx / ||w||·||x||)。
Si = exp(cosθ·||w||·||x||) / ∑exp(cosθ·||w||·||x||)。
Arc-Softmax(加大 θ 角度)
实验(手写数字识别)
数据集:MNIST。
网络结构:CNN + 特征全连接(输出2个特征值)+ 分类全连接(输出10个特征值)。
优化器:Adam。
损失函数:NLLLoss(NLLLoss = CrossEntropyLoss - softmax),需要输入做过 logsoftmax 的预测结果。
输出:one-hot 类型,结果为最大的索引值。
自定义激活函数(Arc-Softmax)
import torch
from torch import nn
from torch.nn import functional as F
class ArcSoftmax(nn.Module):
def __init__(self, cls_num, feature_num):
super().__init__()
# x[n,v] · w[v,c]
self.w = nn.Parameter(torch.randn((feature_num, cls_num)))
def forward(self, x, s, m):
# x → x / ||x||
x_norm = F.normalize(x, dim=1)
# w → w / ||w||
w_norm = F.normalize(self.w, dim=0)
# cosθ = 二范数归一化后的 x·w = (x / ||x||)(w / ||w||)
# /10:防止梯度爆炸,要在后边乘回来
cosa = torch.matmul(x_norm, w_norm) / 10
# 反余弦求角度
a = torch.acos(cosa)
# 全部:torch.sum(torch.exp(s * cosa * 10), dim=1, keepdim=True)
# 当前:torch.exp(s * cosa * 10)
# 加大角度:torch.exp(s * torch.cos(a + m) * 10)
arcsoftmax = torch.exp(s * torch.cos(a + m) * 10) / (
torch.