(二)marginface的loss解读
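margin face 其实是 ArcFace 论文中把以下几种 margin 融合在一起所做的实验:SphereFace 的乘性角度 margin $m_1$、ArcFace 的加性角度 margin $m_2$ 和 CosFace 的加性余弦 margin $m_3$,目标类的 logit 统一写成 $s\,(\cos(m_1\theta + m_2) - m_3)$。下面截取了该 margin 损失函数的代码,并做了部分解读。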
elif args.loss_type==5:#margin face
# Combined margin ("margin face"): the ground-truth logit becomes s*(cos(m1*theta + m2) - m3),
# with m1 = args.margin_a (SphereFace), m2 = args.margin_m (ArcFace), m3 = args.margin_b (CosFace).
s = args.margin_s
m = args.margin_m
# NOTE(review): m is not referenced again in this branch; args.margin_m is read directly below.
assert s>0.0
# L2-normalize the weight and the embedding (mode='instance' normalizes each sample/row), then
# scale the embedding by s, so each entry of fc7 is s * cos(angle between embedding and class weight).
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s# note: the normalized embedding is multiplied by s here
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
# Only modify fc7 when at least one margin differs from its "no margin" value.
if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
if args.margin_a==1.0 and args.margin_m==0.0:
# Pure CosFace (only margin_b set): no arccos needed; subtract s*margin_b from the
# ground-truth logit via a one-hot mask whose "on" value is s*margin_b.
s_m = s*args.margin_b
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
fc7 = fc7-gt_one_hot
else:
# zy: the ground-truth-class logit of each sample, i.e. s*cos(theta).
zy = mx.sym.pick(fc7, gt_label, axis=1)
# pick along axis=1 takes one element per row, indexed by that row's label, and returns a
# 1-D vector with one value per sample, e.g. [[1,2,3],[3,4,5]] with labels [0,1] -> [1,4].
cos_t = zy/s
# cos(theta) for the ground-truth class
t = mx.sym.arccos(cos_t)
# theta = arccos(cos(theta)), the angle in radians
if args.margin_a!=1.0:
t = t*args.margin_a#m1 sphereface
if args.margin_m>0.0:
t = t+args.margin_m #m2 arcface
body = mx.sym.cos(t)
if args.margin_b>0.0: #m3 cosface
body = body - args.margin_b
new_zy = body*s
# NOTE(review): margin_m / margin_b are applied only when > 0 (negative values are ignored even
# though the outer condition uses != 0), and unlike loss_type 4 there is no guard for
# m1*theta + m2 exceeding pi, where cos stops being monotonically decreasing.
diff = new_zy - zy
# diff is the change to the ground-truth logit only; it is added back into fc7 below so that
# all non-target logits stay unchanged.
diff = mx.sym.expand_dims(diff, 1)
# add an axis for broadcasting: [1,3,5] -> [[1],[3],[5]]
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
# one-hot of the labels, e.g. [2,1,0] -> [[0,0,1],[0,1,0],[1,0,0]]
body = mx.sym.broadcast_mul(gt_one_hot, diff)
# gt_one_hot [[0,0,1],[0,1,0],[1,0,0]] * diff [[1],[3],[5]]
# -> [[0,0,1],[0,3,0],[5,0,0]]: each sample's diff lands only in its ground-truth column
fc7 = fc7+body
# fc7 now holds new_zy at each ground-truth position; all other logits are unchanged
(三)ArcFace解读
这里单独实现的 ArcFace 比较不好懂。简单说,它考虑的是 cos 函数的单调性问题:给角度 $\theta$ 加上 margin $m$ 之后,$\cos(\theta + m)$ 只有在 $\theta + m \le \pi$ 时才随 $\theta$ 单调递减,而加完 margin 后仍需保持单调。具体做法如下:
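1、要求 $0 \le \theta + m \le \pi$,即 $\theta \le \pi - m$,也就是 $\cos\theta \ge \cos(\pi - m)$。于是代码设了一个阈值 threshold $= \cos(\pi - m)$。如果是 easy_margin(简单的 margin),就不用这个阈值,直接以 $\cos\theta > 0$ 作为条件。否则计算 cos_t − threshold,并看它经过 relu 后是否为正,以此判断 $\theta + m$ 是否仍在单调区间内。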
2、接下来正常计算 $\cos(\theta + m) = \cos\theta\cos m - \sin\theta\sin m$,其中 $\sin\theta = \sqrt{1 - \cos^2\theta}$。
3、接下来到关键语句 new_zy = mx.sym.where(cond, new_zy, zy_keep):如果 cond 为真(非零),取 new_zy,即 $s\cos(\theta + m)$;否则取 zy_keep。
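4、zy_keep 是不加角度 margin 时的替代值。
- easy_margin 时 zy_keep = zy,即原始 logit $s\cos\theta$。
- 否则 zy_keep = zy − s·mm,其中 mm $= \sin(\pi - m)\cdot m = m\sin m$,即 $s(\cos\theta - m\sin m)$。

它不是 ArcFace 的角度 margin,而是一个 CosFace 风格的加性余弦 margin(大小为 $m\sin m$)。它在 $\theta + m > \pi$ 时使用,保证目标 logit 仍随 $\theta$ 单调递减。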
elif args.loss_type==4:# arc face, with a few tricks explained below
# ArcFace: the ground-truth logit becomes s*cos(theta + m), with a fallback where theta + m
# would leave the range in which cos is monotonically decreasing.
s = args.margin_s
m = args.margin_m
assert s>0.0
assert m>=0.0
assert m<(math.pi/2)
# L2-normalize the weight and the embedding, then scale the embedding by s, so each entry of
# fc7 is s * cos(angle between embedding and class weight).
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
# zy: ground-truth logit per sample (s*cos(theta)); cos_t = cos(theta)
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
cos_m = math.cos(m)
sin_m = math.sin(m)
# mm = sin(pi - m) * m = m * sin(m): size of the additive cosine penalty used by the fallback
mm = math.sin(math.pi-m)*m
#threshold = 0.0
# theta + m <= pi  <=>  theta <= pi - m  <=>  cos(theta) >= cos(pi - m); below this threshold
# cos(theta + m) would no longer decrease monotonically in theta.
threshold = math.cos(math.pi-m)
# cond is nonzero exactly where the ArcFace margin is applied (relu(x) != 0 iff x > 0):
#   easy_margin: where cos(theta) > 0, i.e. theta < pi/2
#   otherwise:   where cos(theta) > cos(pi - m), i.e. theta + m < pi
if args.easy_margin:
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
else:
cond_v = cos_t - threshold
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
# cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m), with sin(theta) = sqrt(1 - cos(theta)^2)
# (the angle between two vectors lies in [0, pi], so sin(theta) >= 0)
body = cos_t*cos_t
body = 1.0-body
sin_t = mx.sym.sqrt(body)
new_zy = cos_t*cos_m
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s
# Fallback logit used where cond is zero:
#   easy_margin: the unchanged logit s*cos(theta) (no margin)
#   otherwise:   s*(cos(theta) - m*sin(m)), a CosFace-style additive cosine margin that stays
#                monotonically decreasing in theta past pi - m
if args.easy_margin:
zy_keep = zy
else:
zy_keep = zy - s*mm
# Key step: where cond is nonzero take new_zy (= s*cos(theta + m)); otherwise take zy_keep.
new_zy = mx.sym.where(cond, new_zy, zy_keep)
# Add (new_zy - zy) only at each sample's ground-truth column, leaving other logits unchanged:
# expand diff to a column, mask it with the label one-hot, and add it to fc7.
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
body = mx.sym.broadcast_mul(gt_one_hot, diff)
fc7 = fc7+body