insightface源码中arcface代码段理解

在看insightface源码中,遇到arcface损失函数的实现,感觉非常难以理解,参考了下面的博客,自己进行了一波强势自我解释,目前理解一部分,谨在此进行记录,以防忘记。
博客地址:http://www.cnblogs.com/darkknightzh/p/8525287.html

代码原文如下:

  all_label = mx.symbol.Variable('softmax_label') 
  gt_label = all_label
  # extra_loss = None
  # 重新定义fc7的权重
  _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
   ## ...省略部分代码
  elif args.loss_type==4: #ArcFace
    s = args.margin_s # 参数s, 64
    m = args.margin_m # 参数m, 0.5

    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    # 权重归一化
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')    # shape = [(类别数目, 512)]
    # 特征归一化,并放大到 s*x
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') #args.num_classes:85164
    
    zy = mx.sym.pick(fc7, gt_label, axis=1) #fc7每一行找出gt_label对应的值, 即s*cos_t

    cos_t = zy/s # 网络输出output = s*x/|x|*w/|w|*cos(theta), 这里将输出除以s,得到实际的cos值,即cos(theta)
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    mm = math.sin(math.pi-m)*m     #sin(pi-m)*m = sin(m) * m  0.2397
    #threshold = 0.0
    threshold = math.cos(math.pi-m)        # 这个阈值避免theta+m >= pi, 实际上threshold<0 -cos(m)    -0.8775825618903726
    if args.easy_margin:      # 将0作为阈值,得到超过阈值的索引
      cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
      cond_v = cos_t - threshold    # 将负数作为阈值
      cond = mx.symbol.Activation(data=cond_v, act_type='relu')
    body = cos_t*cos_t  # cos_t^2 + sin_t^2 = 1
    body = 1.0-body
    sin_t = mx.sym.sqrt(body)
    new_zy = cos_t*cos_m    #cos(t+m) = cos(t)cos(m) - sin(t)sin(m)
    b = sin_t*sin_m
    new_zy = new_zy - b
    new_zy = new_zy*s    # s*cos(t + m)
    if args.easy_margin:
      zy_keep = zy    # zy_keep为zy,即s*cos(theta)
    else:
      zy_keep = zy - s*mm   # zy-s*sin(m)*m = s*cos(t)- s*m*sin(m)
    new_zy = mx.sym.where(cond, new_zy, zy_keep)   #cond中>0的保持new_zy=s*cos(theta+m)不变,<0的裁剪为zy_keep= s*cos(theta) or s*cos(theta)-s*m*sin(m)

    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)    # 对应yi处为new_zy - zy
    fc7 = fc7+body   #对应yi处,fc7=zy + (new_zy - zy) = new_zy,即cond中>0的为s*cos(theta+m),<0的裁剪为s*cos(theta) or s*cos(theta)-s*m*sin(m)

上面代码中,在第18行首先重新初始化了f7层的权重,然后在15、16行对权重进行了归一化,并对特征embedding也进行了归一化操作,同时乘以s进行了放大,这块内容在Normface中进行了说明,属于常规操作。
之后,在第20行要计算zy,这里取的是输出结果中,每一个类别所对应的结果,也就是 s × cos ⁡ t s \times \cos{t} s×cost
然后在22行求出 cos ⁡ t \cos{t} cost,下面就是要对这个 cos ⁡ t \cos{t} cost进行变化,变成arcface中提出的 cos ⁡ ( t + m ) \cos(t+m) cos(t+m)

在论文中,作者实际计算 cos ⁡ ( t + m ) \cos{(t+m)} cost+m用的是下面这个公式: cos ⁡ ( m + t ) = cos ⁡ m c o s t − s i n m s i n t \cos{(m+t)} = \cos{m}cos{t} -sin{m}sin{t} cos(m+t)=cosmcostsinmsint
所以,关键是怎么算?
下面的一坨代码都是为了计算这个式子,之所以这么复杂,是因为 cos ⁡ ( t + m ) \cos(t+m) cos(t+m)并非是单调的。我们知道,在L-softmax,A-softmax中,作者为了解决cos函数不单调的时候,提出了使用 ( − 1 ) m cos ⁡ m t − 2 k (-1)^m\cos{mt}-2k (1)mcosmt2k这个函数来代替原始的cos函数,就是保证在训练过程中,保证函数一直保持在递减的区间,这样算法才是有效的。这里arcface在实现中也进行了处理,只不过不是使用上面的函数而已。

作者在27行计算的时候,使用了一个叫做“threshold”的变量来进行约束。这块内容推敲一下还是可以理解的。因为在原始的输出 cos ⁡ t \cos{t} cost中,t的取值范围为[0, π \pi π],那么如果直接对t加上m,则可能会超过 π \pi π,那么这时候函数就不单调了,所以要在保证 cos ⁡ ( t + m ) \cos(t+m) cos(t+m)工作在[0, π \pi π]范围内。这时候的限制条件就变成了 − m ≤ t ≤ π − m -m\leq t \leq \pi-m mtπm。那么对于cos函数来说,就是要求 cos ⁡ t ≤ cos ⁡ ( π − m ) \cos{t} \leq \cos(\pi -m) costcos(πm)。也就是说, c o s ( π − m ) cos(\pi -m) cos(πm) cos ⁡ t \cos{t} cost的上界,如果超过这个边界,凡是在 c o s ( π − m ) − cos ⁡ t cos(\pi -m) -\cos{t} cos(πm)cost区间可能存在的 t 值,都让其强制变小,那么其对应的 cos ⁡ t \cos{t} cost值最大也就是 cos ⁡ t − c o s ( π − m ) \cos{t} - cos(\pi -m) costcos(πm),这样就保证了 t+m 工作在[0, π \pi π]之间。但是要注意,这里没有不是对 cos ⁡ t \cos{t} cost进行改变,而是将其约束后的值保存在了31行cond变量中,将这个变量送入激活函数中。

然后,39行计算了new_zy变量,这个变量就是 s cos ⁡ ( t + m ) s\cos(t+m) scos(t+m),然后在43行这里计算了一个zy_keep = zy - s*mm,而mm在第25行计算了, m m = m s i n ( π − m ) mm = msin(\pi -m) mm=msin(πm)。所以这里的zy_keep 实际上是 cos ⁡ ( t ) − m sin ⁡ ( m ) \cos(t) - m\sin(m) cos(t)msin(m)。这个地方的确是非常费解,我从insightface中的各种版本的源码中都进行查阅,大概的意思是,当 t 不在[0, π \pi π]范围内时,作者使用的cosface中的实现,但是cosface中的实现用的是 cos ⁡ ( t ) − m \cos(t)-m cos(t)m,而不是 cos ⁡ ( t ) − sin ⁡ ( m ) × m \cos(t) -\sin(m) \times m cos(t)sin(m)×m,所以这里还是不太理解,希望以后有能看到某位大神的解释吧。

下面44行就是的new_zy就是选取满足条件的 cos ⁡ ( t + m ) \cos(t+m) cos(t+m),如果没有超出范围,那么就保持原样,如果超出了,就用它cosface代替。

当然,上面的分析实际上感觉作者在这个地方有点多余,因为论文中作者已经分析过了,角度不会出现在150度以上,所以基本也不会出现越界的可能。所以在后面当loss_type=5的时候,作者也没有进行其他的约束,直接求 cos ⁡ ( t ) \cos(t) cos(t) 的反三角,然后进行常规的 cos ⁡ ( t + m ) \cos(t+m) cos(t+m),没有进行其他的约束。但不管怎么说,在训练arcfac的时候,作者是用了一些小trick的。
各loss联合训练

elif args.loss_type==5:
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
      if args.margin_a==1.0 and args.margin_m==0.0:
        s_m = s*args.margin_b
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
      else:
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)  #这里直接反三角
        if args.margin_a!=1.0:
          t = t*args.margin_a
        if args.margin_m>0.0:
          t = t+args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b>0.0:
          body = body - args.margin_b  #这里直接计算cos(t+m)
        new_zy = body*s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
  • 7
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值