解析ArcFace源码

猫猫与橙子 2019-10-09 17:20:38  893  收藏 2
展开

论文分享,代码复现
主要用于论文翻译分享,代码复现,结合场景数据,提升实际模型性能
猫猫与橙子
¥19.90
分享赚¥1.99 订阅专栏
最近在看arcFace当中pytorch代码的实现,先把InsightFace中ArcFace代码贴出来: 

class Arcface(Module):
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599    
    def __init__(self, embedding_size=512, classnum=51332,  s=64., m = 0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) #uniform_(-1, 1)服从均匀分布,mul_对应点相乘
        self.m = m # the margin value, default is 0.5
        self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # issue 1
        self.threshold = math.cos(math.pi - m)
    def forward(self, embbedings, label):
        # weights norm
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel, axis=0)
        # cos(theta+m)
        cos_theta = torch.mm(embbedings, kernel_norm)#进行矩阵乘法
#         output = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # this condition controls the theta+m should in range [0, pi]
        #      0<=theta+m<=pi
        #     -m<=theta<=pi-m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s # scale up in order to make softmax work, first introduced in normface
        return output
对于torch中一些函数的理解

1)对于self.kernel = Parameter(torch.Tensor(embedding_size, classnum))中,Parameter的作用:

首先可以把Parameter理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.kernel变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。(摘自:原文链接:https://blog.csdn.net/qq_36955294/article/details/88117170)

看了torch官网的解释:
"Variable的一种,常被用于模块参数(module parameter)。

Parameters 是 Variable 的子类。Paramenters和Modules一起使用的时候会有一些特殊的属性,即:当Paramenters赋值给Module的属性的时候,他会自动的被加到 Module的 参数列表中(即:会出现在 parameters() 迭代器中)。将Varibale赋值给Module属性则不会有这样的影响。 这样做的原因是:我们有时候会需要缓存一些临时的状态(state)"

这句话中,embedding_size = 512,classnum是人脸识别的ID数,先使用orch.Tensor,生成一个512×classnum的张量,然后通过Parameter将这个张量转化为可以训练的模型;

2)对于self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)  的理解:

#uniform_(from=-1, to=1) → Tensor    将tensor用从均匀分布中抽样得到的值填充。

# renorm_返回一个张量,包含规范化后的各个子张量,使得沿着2维划分的各子张量的1范数小于1e-5

# mul_用标量值1e5乘以输入input的每个元素,并返回一个新的结果张量;

以上是对pytorch中一些函数的理解;

 

对于arcface公式的代码实现

对于arcFace的实现实际应该是包括两部分,第一部分是cosin函数部分;第二部分就是常规的softmax部分;

在pytorch代码中,第二部分直接有函数实现,是可以直接使用的;所以重点是cosin函数部分的实现;

下面就重点讲解记录一下怎样一步步的实现第一部分代码:

1)对Feature进行了l2 norm,对参数也进行了l2 norm.所以权值参数×feature = cos theta

2)将cos theta夹逼到【-1, 1】之间,因为cos theta的定义域在【0,pi】值域实在【-1,1】之间;

3)计算cos(theta + m)使用到余弦定理;

4)计算完成后,要判断theta是否超出范围,进行数据调整,这一块的判读原理在下图:

(不知道这样理解是否有错?望大佬赐教)

判断后得出一个值为0或1的mask,通过使用cos_theta_m[cond_mask] = keep_val[cond_mask],将超出范围的值使用keep_val表示,加入[cond_mask],是将mask为1(True)位置的元素取出,进行修改;
————————————————
版权声明:本文为CSDN博主「猫猫与橙子」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_22764813/article/details/101437898

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI周红伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值