未完待续……
一、前言
这里主要结合insightface中片段代码对人脸识别常用的一些loss进行记录。
代码地址:https://github.com/deepinsight/insightface
二、主要内容
get_symbol() 函数:
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: a backbone embedding network plus a
    margin-based fc7 classification head selected by ``args.loss_type``.

    Parameters
    ----------
    args : argparse.Namespace-like
        Training configuration; fields read here include network/num_layers,
        emb_size, num_classes, loss_type, margin_* and version_* options.
    arg_params, aux_params : dict
        Passed through unchanged (returned as-is alongside the symbol).

    Returns
    -------
    (out, arg_params, aux_params)
        ``out`` is an ``mx.symbol.Group`` containing the (gradient-blocked)
        embedding, the softmax output, and optional auxiliary losses.
    """
    data_shape = (args.image_channel,args.image_h,args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    # Backbone selection: the first character of args.network picks the
    # architecture; each f* module returns an embedding symbol of size
    # args.emb_size.
    if args.network[0]=='d':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='m':
        print('init mobilenet', args.num_layers)
        # num_layers==1 selects mobilenet v1, anything else mobilenet v2.
        if args.num_layers==1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                version_input=args.version_input,
                version_output=args.version_output,
                version_multiplier = args.version_multiplier)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0]=='i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0]=='n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0]=='s':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0]=='y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom = args.bn_mom, version_output=args.version_output)
    else:
        # Default backbone: resnet.
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit,
            version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    # Classifier weight (num_classes x emb_size) with configurable
    # learning-rate / weight-decay multipliers.
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
    if args.loss_type==0: #softmax
        # Plain softmax classifier; bias optional.
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type==1: #sphere
        # Multiplicative angular margin via the custom LSoftmax operator
        # (SphereFace-style); only the weight is L2-normalized here.
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
            weight = _weight,
            beta=args.beta, margin=args.margin, scale=args.scale,
            beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type==2:
        # Additive cosine margin (CosFace/AM-softmax style):
        # logits = s*cos(theta), then subtract s*m from the target-class logit.
        s = args.margin_s
        m = args.margin_m
        assert(s>0.0)
        assert(m>0.0)
        # Normalize both weight rows and embeddings so fc7 = s * cos(theta).
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        s_m = s*m
        # One-hot with on_value=s*m subtracts the margin only at the label.
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
    elif args.loss_type==4:
        # Additive angular margin (ArcFace style): replace the target logit
        # s*cos(theta) with s*cos(theta + m).
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert m>=0.0
        assert m<(math.pi/2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        # zy: the logit at each sample's ground-truth class; cos_t = cos(theta).
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        # mm = sin(pi-m)*m: slope of the linear fallback used when theta + m
        # would pass the monotonic range of cos.
        mm = math.sin(math.pi-m)*m
        #threshold = 0.0
        threshold = math.cos(math.pi-m) # negative, since 0 <= m < pi/2
        if args.easy_margin:
            # easy_margin: apply the margin only while cos(theta) > 0.
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            # Apply the margin while cos(theta) > cos(pi - m), i.e. theta + m < pi.
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        # cos(t+m) = cos t * cos m - sin t * sin m, with sin t = sqrt(1 - cos^2 t).
        body = cos_t*cos_t
        body = 1.0-body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t*cos_m
        b = sin_t*sin_m
        new_zy = new_zy - b
        new_zy = new_zy*s
        if args.easy_margin:
            # Outside the margin region: keep the original logit.
            zy_keep = zy
        else:
            # Outside the margin region: use the linear fallback s*(cos t - mm).
            zy_keep = zy - s*mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        # Add (new_zy - zy) only at the ground-truth position via a one-hot mask.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
    elif args.loss_type==5:
        # Combined margin: target logit becomes s*(cos(a*theta + m) - b),
        # covering SphereFace (a), ArcFace (m) and CosFace (b) as special cases.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
            if args.margin_a==1.0 and args.margin_m==0.0:
                # Pure cosine margin (only b active): cheap one-hot subtraction.
                s_m = s*args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
                fc7 = fc7-gt_one_hot
            else:
                # General case: recover theta, apply a*theta + m, re-take cos.
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy/s
                t = mx.sym.arccos(cos_t)
                if args.margin_a!=1.0:
                    t = t*args.margin_a
                if args.margin_m>0.0:
                    t = t+args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b>0.0:
                    body = body - args.margin_b
                new_zy = body*s
                # Patch only the ground-truth logit via a one-hot mask.
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7+body
    elif args.loss_type==6:
        # ArcFace margin plus an auxiliary intra-class loss: mean angle between
        # each embedding and its class center, normalized to [0,1] by pi,
        # weighted by args.margin_b via MakeLoss grad_scale.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert args.margin_b>0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        intra_loss = t/np.pi
        intra_loss = mx.sym.mean(intra_loss)
        #intra_loss = mx.sym.exp(cos_t*-1.0)
        intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = args.margin_b)
        if m>0.0:
            # Standard additive angular margin on the target logit.
            t = t+m
            body = mx.sym.cos(t)
            new_zy = body*s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7+body
    elif args.loss_type==7:
        # ArcFace margin plus an auxiliary inter-class loss: mean cosine
        # similarity between each sample's class-center row and all class
        # centers (encourages centers to spread apart), weighted by
        # args.margin_b. NOTE(review): args.margin_a is asserted > 0 but only
        # used in the commented-out variants below.
        s = args.margin_s
        m = args.margin_m
        assert s>0.0
        assert args.margin_b>0.0
        assert args.margin_a>0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        #counter_weight = mx.sym.take(_weight, gt_label, axis=1)
        #counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
        # Gather the (normalized) center row of each sample's class, then its
        # cosine against every class center.
        counter_weight = mx.sym.take(_weight, gt_label, axis=0)
        counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = counter_angle * -1.0
        #counter_angle = counter_angle/np.pi #[0,1]
        #inter_loss = mx.sym.exp(counter_angle)
        #counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = mx.sym.sort(counter_angle, axis=1)
        #counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0,end=int(args.margin_a))
        #inter_loss = counter_angle*-1.0 # [-1,0]
        #inter_loss = inter_loss+1.0 # [0,1]
        inter_loss = counter_cos
        inter_loss = mx.sym.mean(inter_loss)
        inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = args.margin_b)
        if m>0.0:
            # Standard additive angular margin on the target logit.
            t = t+m
            body = mx.sym.cos(t)
            new_zy = body*s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7+body
    # Assemble outputs: embedding is gradient-blocked (inference output only);
    # softmax drives the main classification gradient.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    if args.loss_type==6:
        out_list.append(intra_loss)
    if args.loss_type==7:
        out_list.append(inter_loss)
        #out_list.append(mx.sym.BlockGrad(counter_weight))
        #out_list.append(intra_loss)
    if args.ce_loss:
        # Hand-rolled cross-entropy for monitoring only (BlockGrad'd below),
        # averaged over the per-device batch size.
        #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = -1.0, off_value = 0.0)
        body = body*_label
        ce_loss = mx.symbol.sum(body)/args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)