论文名称:ArcFace: Additive Angular Margin Loss for Deep Face Recognition
论文下载地址:https://arxiv.org/abs/1801.07698
代码(作者本人的开源×欢迎star):https://github.com/wuji3/visiondk(建议直接访问,不要用csdn的github加速计划,会有版本错误)
目录
背景
在人脸识别领域内,有一系列的工作在研究如何基于精心设计的softmax loss function中添加margin,实现class separability。从最开始的triplet loss,到sphereface和cosface等等。triplet loss是人脸领域中开创性的工作,引入了margin的概念实现class separability,但是它的局限性在于1.学习的embedding是在欧式空间,2.只有困难样例才有意义,训练效率低。后面基于softmax的loss,大多更新了欧式embedding到余弦embedding,看到这里可能有人会问,为什么欧式embedding不如余弦embedding呢?因为欧式空间衡量的是相同程度(same),而余弦空间度量的是相似程度(similar),即使同一个identity,不同角度不同时间拍出来的图像,在特征空间内不可能element-wise相同,但确实是同一个id,衡量相似性更加符合人脸的业务。既然是基于余弦的特征空间,如何设计余弦的margin才是算法的核心。给余弦距离加margin,实际上等价于给余弦相似度减margin,那么怎么减掉这个margin呢?
ArcFace
接背景部分,带着“如何给余弦相似度减margin?”这个疑问,我们看看ArcFace是如何解决这个事情的。
对于正例,某个identity的sample和center,它们的相似度是,如果给加一个,加到的位置,那么。回忆一下我们的目标是什么?目标是“如何给余弦相似度减margin”,至此,ArcFace的设计哲学就讲完了。
再理一遍,具体是怎么做的呢?先给求,把算出来,加得到,再求算回去,得到。margin加进去了,在arc-level加进去的,之后呢?用celoss监督模型最大化。注意,本来模型度量positive相似度的能力是,你硬生生给它加了难度,变成了,于是模型开始玩命的最大化,等价于最小化positive的,实现class separability。
那负例呢?负例加margin嘛?如果你有这个疑问,先思考你想给负例加margin的目标是什么?你可能说为了实现class separability。ok,实现class separability要做两件事,1. small intra-class distance, 2. large inter-class distance。通过给负例加margin你要实现哪件事?你只能为了第二件事。你一定说通过给负例加margin,在特征空间中把负例推的更远,实现large inter-class distance。实际上,celoss已经完成了这件事。celoss完不成的事,是第一件small intra-class distance,而正例加margin正弥补了celoss在第一件事的能力缺陷。
代码实现
class ArcFace(Module):
"""Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition"
"""
def __init__(self, feat_dim, num_class, margin_arc=0.35, margin_am=0.0, scale=32):
super(ArcFace, self).__init__()
self.weight = Parameter(torch.Tensor(feat_dim, num_class))
self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
self.margin_arc = margin_arc
self.margin_am = margin_am
self.scale = scale
self.cos_margin = math.cos(margin_arc)
self.sin_margin = math.sin(margin_arc)
self.min_cos_theta = math.cos(math.pi - margin_arc)
def forward(self, feats, labels):
kernel_norm = F.normalize(self.weight, dim=0)
feats = F.normalize(feats)
cos_theta = torch.mm(feats, kernel_norm)
cos_theta = cos_theta.clamp(-1, 1)
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin
# 0 <= theta + m <= pi, ==> -m <= theta <= pi-m
# because 0<=theta<=pi, so, we just have to keep theta <= pi-m, that is cos_theta >= cos(pi-m)
cos_theta_m = torch.where(cos_theta > self.min_cos_theta, cos_theta_m, cos_theta-self.margin_am)
index = torch.zeros_like(cos_theta)
index.scatter_(1, labels.data.view(-1, 1), 1)
index = index.byte().bool()
output = cos_theta * 1.0
output[index] = cos_theta_m[index]
output *= self.scale
return output