代码:
self.th = math.cos(math.pi - m)
self.mm = math.sin(math.pi - m) * m
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
- easy_margin的解释
只对cosine > 0的项添加margin惩罚,虽然函数整体不再满足单调递减的性质,但是总体上绝大部分样本与w的夹角小于pi/2,所以影响不会太大。
- hard_margin的解释
对代码中else部分的解释,整体满足单调递减的性质。
参考:
https://github.com/deepinsight/insightface/issues/22
https://github.com/niliusha123/Margin-based-Softmax/blob/main/ArcFace-easy_margin%E7%9A%84%E8%AF%81%E6%98%8E.jpg
https://github.com/deepinsight/insightface/issues/108