论文信息
- ArcFace: Additive Angular Margin Loss for Deep Face Recognition[2019-CVPR]
- Author : Jiankang Deng, Jia Guo, Niannan Xue, Stefanos Zafeiriou
- Citation : 280+
- Github : metricface-pytorch
关键技术
Arcface-pytorch
class Arcface(nn.Module):
    r"""Additive angular margin (ArcFace) classification head.

    Computes the cosine between the L2-normalized input rows and the
    L2-normalized class weight rows, replaces the target-class cosine
    ``cos(theta)`` with the margin-penalized ``cos(theta + m)``, and
    multiplies all logits by ``s``.

    Args:
        in_features: size of each input sample (number of weight columns).
        out_features: size of each output sample (number of weight rows,
            i.e. one row per class).
        s: factor the final logits are multiplied by.
        m: additive angular margin, in radians.
        easy_margin: if True, the margin is applied only where
            ``cos(theta) > 0`` and the plain cosine is used elsewhere.
            If False, the margin is applied where
            ``cos(theta) > cos(pi - m)`` and ``cos(theta) - sin(pi - m) * m``
            is used elsewhere.
        use_gpu: if True, the weight tensor is moved with ``.cuda()``
            before being wrapped as a parameter.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, use_gpu=False):
        super(Arcface, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        # One weight row per class; contents are overwritten by the Xavier
        # initialisation below, so an uninitialised buffer is sufficient.
        weight = torch.empty(out_features, in_features, dtype=torch.float32)
        if use_gpu:
            weight = weight.cuda()
        self.weight = Parameter(weight)
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold and offset used by the non-easy-margin fallback branch.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, target):
        """Return the scaled margin logits.

        Args:
            input: feature tensor; it is normalized with ``F.normalize``
                and multiplied against the ``(out_features, in_features)``
                weight, so its last dimension must be ``in_features``.
            target: class indices; reshaped to ``(-1, 1)`` and used as
                column indices along dim 1 of the logit matrix, i.e. one
                index per row.

        Returns:
            Tensor shaped like the cosine matrix, equal to
            ``s * cos(theta + m)`` (or its fallback) in the target column
            of each row and ``s * cos(theta)`` everywhere else.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))

        # sqrt has an infinite derivative at 0, which is reached whenever a
        # cosine is exactly +-1 and then poisons every gradient with inf/NaN.
        # Clamping the radicand to the smallest positive normal number keeps
        # the forward value numerically unchanged while the clamp masks the
        # gradient at that singular point.
        tiny = torch.finfo(cosine.dtype).tiny
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(tiny, 1))

        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            # Apply the margin only where the cosine is positive.
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Where cosine <= cos(pi - m), use the linear fallback
            # cosine - sin(pi - m) * m instead of cos(theta + m).
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # 1 in the target column of each row, 0 elsewhere.
        one_hot = torch.zeros(cosine.size(), device=input.device)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)

        # Margin logit for the target class, plain cosine for the others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
Arcface vs (Sphereface, Cosface) ?
Experiment