# 人脸识别 ArcFace 损失函数（代码）/ ArcFace loss for face recognition (code)
class ArcLoss1(nn.Module):
    """ArcFace (additive angular margin) head returning per-class probabilities.

    Entry ``[i, j]`` of the output is the ArcFace probability of sample ``i``
    under the hypothesis that its label is ``j``::

                       exp(s*cos(theta_ij + m))
        ---------------------------------------------------------------
        exp(s*cos(theta_ij + m)) + sum_{k != j} exp(s*cos(theta_ik))

    where ``theta_ij`` is the angle between the L2-normalised feature of
    sample ``i`` and the L2-normalised column ``j`` of ``self.w``.
    Presumably the caller selects the label column afterwards
    (e.g. ``nll_loss(log(out), label)``) -- TODO confirm against the caller.
    No special handling is applied when ``theta + m`` exceeds pi.

    Args:
        class_num: number of classes (columns of ``self.w``).
        feature_num: feature dimension (rows of ``self.w``).
        s: scale applied to every cosine before exponentiation.
        m: additive angular margin, in radians.
    """

    def __init__(self, class_num, feature_num, s=10, m=0.1):
        super().__init__()
        self.class_num = class_num
        self.feature_num = feature_num
        self.s = s
        self.m = torch.tensor(m)
        # NOTE(review): torch.rand draws weights from [0, 1), so every class
        # centre starts in the positive orthant; torch.randn may be preferable.
        self.w = nn.Parameter(torch.rand(feature_num, class_num))

    def forward(self, feature):
        """Return ArcFace probabilities of shape (batch, class_num).

        Args:
            feature: assumed shape (batch, feature_num); it is L2-normalised
                here, so its magnitude does not matter.

        Returns:
            Tensor of shape (batch, class_num) with values in [0, 1].
        """
        # Keep cosines strictly inside (-1, 1): d/dc sqrt(1 - c^2) is infinite
        # at |c| = 1, which would turn the backward pass into inf/NaN.
        eps = 1e-7
        feature = F.normalize(feature, dim=1)
        w = F.normalize(self.w, dim=0)
        # Use the plain cosine similarity. Rescaling it (e.g. by 1/10) would
        # squeeze every angle towards 90 degrees, so the margin would act on
        # the wrong angle and the logits could never exceed s/10.
        cos_theta = torch.matmul(feature, w).clamp(-1.0 + eps, 1.0 - eps)
        # sin(theta) >= 0 because theta = acos(cos_theta) lies in [0, pi].
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        cos_theta_m = cos_theta * torch.cos(self.m) - sin_theta * torch.sin(self.m)

        logits = self.s * cos_theta
        target_logits = self.s * cos_theta_m
        # Subtracting one per-row constant from every exponent leaves each
        # ratio unchanged but keeps exp() from overflowing for large s.
        shift = logits.max(dim=1, keepdim=True)[0].detach()
        exp_logits = torch.exp(logits - shift)
        top = torch.exp(target_logits - shift)
        # Row total minus the class's own term = sum over all k != j.
        others = exp_logits.sum(dim=1, keepdim=True) - exp_logits
        return top / (top + others)
class ArcLoss2(nn.Module):
    """ArcFace (additive angular margin) head returning per-class log-probabilities.

    Entry ``[i, j]`` of the output is::

        log( exp(s*cos(theta_ij + m))
             / (exp(s*cos(theta_ij + m)) + sum_{k != j} exp(s*cos(theta_ik))) )

    where ``theta_ij`` is the angle between the L2-normalised feature of
    sample ``i`` and the L2-normalised column ``j`` of ``self.W``.
    Presumably the caller selects the label column afterwards
    (e.g. ``nll_loss(out, label)``) -- TODO confirm against the caller.
    No special handling is applied when ``theta + m`` exceeds pi.

    Args:
        feature_dim: feature dimension (rows of ``self.W``).
        cls_dim: number of classes (columns of ``self.W``).
    """

    def __init__(self, feature_dim=2, cls_dim=10):
        super().__init__()
        self.W = nn.Parameter(torch.randn(feature_dim, cls_dim))

    def forward(self, feature, m=1, s=10):
        """Return ArcFace log-probabilities of shape (batch, cls_dim).

        Args:
            feature: assumed shape (batch, feature_dim); it is L2-normalised
                here, so its magnitude does not matter.
            m: additive angular margin, in radians.
            s: scale applied to every cosine before exponentiation.

        Returns:
            Tensor of shape (batch, cls_dim); values are <= 0 up to rounding.
        """
        # Keep cosines strictly inside (-1, 1): acos has an infinite
        # derivative at +/-1, which would produce inf/NaN gradients.
        eps = 1e-7
        x = F.normalize(feature, dim=1)
        w = F.normalize(self.W, dim=0)
        # Use the plain cosine similarity. Rescaling it (e.g. by 1/10) would
        # squeeze every angle towards 90 degrees, so the margin would act on
        # the wrong angle and the logits could never exceed s/10.
        cos = torch.matmul(x, w).clamp(-1.0 + eps, 1.0 - eps)
        a = torch.acos(cos)

        logits = s * cos                     # s*cos(theta); no acos/cos round-trip
        target_logits = s * torch.cos(a + m)  # s*cos(theta + m)
        # Subtracting one per-row constant from every exponent cancels in the
        # final expression but keeps exp() from overflowing for large s.
        shift = logits.max(dim=1, keepdim=True)[0].detach()
        exp_logits = torch.exp(logits - shift)
        top = torch.exp(target_logits - shift)
        # Row total minus the class's own term = sum over all k != j.
        down2 = exp_logits.sum(dim=1, keepdim=True) - exp_logits
        # log(top / (top + down2)) written as a difference of logs: if top
        # underflows to 0 this stays a finite large negative number instead
        # of log(0) = -inf.
        return (target_logits - shift) - torch.log(top + down2)
# 代码先放这儿，后面再来解释。/ Code first; the explanation will follow later.