在看insightface源码中,遇到arcface损失函数的实现,感觉非常难以理解,参考了下面的博客,自己进行了一波强势自我解释,目前理解一部分,谨在此进行记录,以防忘记。
博客地址:http://www.cnblogs.com/darkknightzh/p/8525287.html
代码原文如下:
all_label = mx.symbol.Variable('softmax_label')
gt_label = all_label
# extra_loss = None
# 重新定义fc7的权重
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
## ...省略部分代码
elif args.loss_type==4: #ArcFace
s = args.margin_s # 参数s, 64
m = args.margin_m # 参数m, 0.5
assert s>0.0
assert m>=0.0
assert m<(math.pi/2)
# 权重归一化
_weight = mx.symbol.L2Normalization(_weight, mode='instance') # shape = [(类别数目, 512)]
# 特征归一化,并放大到 s*x
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') #args.num_classes:85164
zy = mx.sym.pick(fc7, gt_label, axis=1) #fc7每一行找出gt_label对应的值, 即s*cos_t
cos_t = zy/s # 网络输出output = s*x/|x|*w/|w|*cos(theta), 这里将输出除以s,得到实际的cos值,即cos(theta)
cos_m = math.cos(m)
sin_m = math.sin(m)
mm = math.sin(math.pi-m)*m #sin(pi-m)*m = sin(m) * m 0.2397
#threshold = 0.0
threshold = math.cos(math.pi-m) # 这个阈值避免theta+m >= pi, 实际上threshold<0 -cos(m) -0.8775825618903726
if args.easy_margin: # 将0作为阈值,得到超过阈值的索引
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
else:
cond_v = cos_t - threshold # 将负数作为阈值
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
body = cos_t*cos_t # cos_t^2 + sin_t^2 = 1
body = 1.0-body
sin_t = mx.sym.sqrt(body)
new_zy = cos_t*cos_m #cos(t+m) = cos(t)cos(m) - sin(t)sin(m)
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s # s*cos(t + m)
if args.easy_margin:
zy_keep = zy # zy_keep为zy,即s*cos(theta)
else:
zy_keep = zy - s*mm # zy-s*sin(m)*m = s*cos(t)- s*m*sin(m)
new_zy = mx.sym.where(cond, new_zy, zy_keep) #cond中>0的保持new_zy=s*cos(theta+m)不变,<0的裁剪为zy_keep= s*cos(theta) or s*cos(theta)-s*m*sin(m)
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff) # 对应yi处为new_zy - zy
fc7 = fc7+body #对应yi处,fc7=zy + (new_zy - zy) = new_zy,即cond中>0的为s*cos(theta+m),<0的裁剪为s*cos(theta) or s*cos(theta)-s*m*sin(m)
上面代码中,在第18行首先重新初始化了f7层的权重,然后在15、16行对权重进行了归一化,并对特征embedding也进行了归一化操作,同时乘以s进行了放大,这块内容在Normface中进行了说明,属于常规操作。
之后,在第20行要计算zy,这里取的是输出结果中,每一个类别所对应的结果,也就是
s
×
cos
t
s \times \cos{t}
s×cost。
然后在22行求出
cos
t
\cos{t}
cost,下面就是要对这个
cos
t
\cos{t}
cost进行变化,变成arcface中提出的
cos
(
t
+
m
)
\cos(t+m)
cos(t+m)。
在论文中,作者实际计算
cos
(
t
+
m
)
\cos{(t+m)}
cos(t+m)用的是下面这个公式:
cos
(
m
+
t
)
=
cos
m
c
o
s
t
−
s
i
n
m
s
i
n
t
\cos{(m+t)} = \cos{m}cos{t} -sin{m}sin{t}
cos(m+t)=cosmcost−sinmsint
所以,关键是怎么算?
下面的一坨代码都是为了计算这个式子,之所以这么复杂,是因为
cos
(
t
+
m
)
\cos(t+m)
cos(t+m)并非是单调的。我们知道,在L-softmax,A-softmax中,作者为了解决cos函数不单调的时候,提出了使用
(
−
1
)
m
cos
m
t
−
2
k
(-1)^m\cos{mt}-2k
(−1)mcosmt−2k这个函数来代替原始的cos函数,就是保证在训练过程中,保证函数一直保持在递减的区间,这样算法才是有效的。这里arcface在实现中也进行了处理,只不过不是使用上面的函数而已。
作者在27行计算的时候,使用了一个叫做“threshold”的变量来进行约束。这块内容推敲一下还是可以理解的。因为在原始的输出 cos t \cos{t} cost中,t的取值范围为[0, π \pi π],那么如果直接对t加上m,则可能会超过 π \pi π,那么这时候函数就不单调了,所以要在保证 cos ( t + m ) \cos(t+m) cos(t+m)工作在[0, π \pi π]范围内。这时候的限制条件就变成了 − m ≤ t ≤ π − m -m\leq t \leq \pi-m −m≤t≤π−m。那么对于cos函数来说,就是要求 cos t ≤ cos ( π − m ) \cos{t} \leq \cos(\pi -m) cost≤cos(π−m)。也就是说, c o s ( π − m ) cos(\pi -m) cos(π−m)是 cos t \cos{t} cost的上界,如果超过这个边界,凡是在 c o s ( π − m ) − cos t cos(\pi -m) -\cos{t} cos(π−m)−cost区间可能存在的 t 值,都让其强制变小,那么其对应的 cos t \cos{t} cost值最大也就是 cos t − c o s ( π − m ) \cos{t} - cos(\pi -m) cost−cos(π−m),这样就保证了 t+m 工作在[0, π \pi π]之间。但是要注意,这里没有不是对 cos t \cos{t} cost进行改变,而是将其约束后的值保存在了31行cond变量中,将这个变量送入激活函数中。
然后,39行计算了new_zy变量,这个变量就是 s cos ( t + m ) s\cos(t+m) scos(t+m),然后在43行这里计算了一个zy_keep = zy - s*mm,而mm在第25行计算了, m m = m s i n ( π − m ) mm = msin(\pi -m) mm=msin(π−m)。所以这里的zy_keep 实际上是 cos ( t ) − m sin ( m ) \cos(t) - m\sin(m) cos(t)−msin(m)。这个地方的确是非常费解,我从insightface中的各种版本的源码中都进行查阅,大概的意思是,当 t 不在[0, π \pi π]范围内时,作者使用的cosface中的实现,但是cosface中的实现用的是 cos ( t ) − m \cos(t)-m cos(t)−m,而不是 cos ( t ) − sin ( m ) × m \cos(t) -\sin(m) \times m cos(t)−sin(m)×m,所以这里还是不太理解,希望以后有能看到某位大神的解释吧。
下面44行就是的new_zy就是选取满足条件的 cos ( t + m ) \cos(t+m) cos(t+m),如果没有超出范围,那么就保持原样,如果超出了,就用它cosface代替。
当然,上面的分析实际上感觉作者在这个地方有点多余,因为论文中作者已经分析过了,角度不会出现在150度以上,所以基本也不会出现越界的可能。所以在后面当loss_type=5的时候,作者也没有进行其他的约束,直接求
cos
(
t
)
\cos(t)
cos(t) 的反三角,然后进行常规的
cos
(
t
+
m
)
\cos(t+m)
cos(t+m),没有进行其他的约束。但不管怎么说,在训练arcfac的时候,作者是用了一些小trick的。
elif args.loss_type==5:
s = args.margin_s
m = args.margin_m
assert s>0.0
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
if args.margin_a==1.0 and args.margin_m==0.0:
s_m = s*args.margin_b
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
fc7 = fc7-gt_one_hot
else:
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
t = mx.sym.arccos(cos_t) #这里直接反三角
if args.margin_a!=1.0:
t = t*args.margin_a
if args.margin_m>0.0:
t = t+args.margin_m
body = mx.sym.cos(t)
if args.margin_b>0.0:
body = body - args.margin_b #这里直接计算cos(t+m)
new_zy = body*s
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body