insightface 中损失函数（loss）的记录

未完待续……

一、前言

这里主要结合 insightface 中的代码片段，对人脸识别中常用的一些 loss 进行记录。
代码地址：https://github.com/deepinsight/insightface

二、主要内容

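get_symbol()

get_symbol() 的流程如下。先根据 args.network 的首字母构建骨干网络，得到 embedding。再根据 args.loss_type 在其上接分类层（fc7）和不同的损失：
- 0 为普通 softmax；
- 1 为 LSoftmax（sphere）；
- 2 为余弦间隔，目标类 logit 变为 s·(cos θ − m)；
- 4 为角度间隔，目标类 logit 变为 s·cos(θ + m)（ArcFace）；
- 5 为组合间隔 s·(cos(a·θ + m) − b)；
- 6、7 分别在角度间隔的基础上额外加入类内损失和类间损失。

其中 θ 为特征与其真实类别权重之间的夹角。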

 def get_symbol(args, arg_params, aux_params):
     """Build the training graph: backbone embedding, classification head and losses.

     The backbone is selected by ``args.network[0]`` (see ``_build_embedding``).
     The head is selected by ``args.loss_type``. Below, ``t`` is the angle between
     an embedding and the weight vector of its ground-truth class:

       * 0 -- plain fully connected softmax (with bias unless ``fc7_no_bias``).
       * 1 -- ``mx.sym.LSoftmax`` on L2-normalized weights (marked "sphere").
       * 2 -- additive cosine margin: target logit s*cos(t) -> s*(cos(t) - m).
       * 4 -- additive angular margin: target logit -> s*cos(t + m). Where the
              margin is not applied (see ``easy_margin`` below), a fallback
              value is used instead.
       * 5 -- combined margin: target logit -> s*(cos(a*t + m) - b). Each term
              is applied only when a != 1, m > 0 or b > 0 respectively. When
              a == 1 and m == 0, only s*b is subtracted from the target logit.
       * 6 -- angular margin s*cos(t + m) (only if m > 0) plus an intra-class
              loss, the mean of t/pi.
       * 7 -- angular margin s*cos(t + m) (only if m > 0) plus an inter-class
              loss, the mean cosine between each sample's class weight and
              every class weight.

     Args:
         args: training options. Fields read here: network, num_layers, emb_size,
             num_classes, loss_type, fc7_lr_mult, fc7_wd_mult, fc7_no_bias,
             margin_s/margin_m/margin_a/margin_b, beta/margin/scale/beta_min
             (type 1), easy_margin (type 4), ce_loss, per_batch_size, plus the
             backbone options read by ``_build_embedding``.
         arg_params: returned unchanged.
         aux_params: returned unchanged.

     Returns:
         ``(out, arg_params, aux_params)``. ``out`` is an ``mx.symbol.Group`` of
         ``[BlockGrad(embedding), SoftmaxOutput, intra_loss (type 6 only),
         inter_loss (type 7 only), BlockGrad(ce_loss) (only if args.ce_loss)]``.

     Raises:
         ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5, 6, 7.
     """
     embedding = _build_embedding(args)

     gt_label = mx.symbol.Variable('softmax_label')
     _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size),
                                  lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
     intra_loss = None
     inter_loss = None

     if args.loss_type == 0:  # softmax
         if args.fc7_no_bias:
             fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True,
                                         num_hidden=args.num_classes, name='fc7')
         else:
             _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
             fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias,
                                         num_hidden=args.num_classes, name='fc7')

     elif args.loss_type == 1:  # sphere
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                               weight=_weight,
                               beta=args.beta, margin=args.margin, scale=args.scale,
                               beta_min=args.beta_min, verbose=1000, name='fc7')

     elif args.loss_type == 2:  # additive cosine margin
         s = args.margin_s
         m = args.margin_m
         assert s > 0.0
         assert m > 0.0
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = _scaled_cosine_fc7(embedding, _weight, s, args.num_classes)
         # Subtract s*m in the ground-truth column only: s*cos(t) -> s*(cos(t) - m).
         gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes,
                                     on_value=s * m, off_value=0.0)
         fc7 = fc7 - gt_one_hot

     elif args.loss_type == 4:  # additive angular margin
         s = args.margin_s
         m = args.margin_m
         assert s > 0.0
         assert m >= 0.0
         assert m < (math.pi / 2)
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = _scaled_cosine_fc7(embedding, _weight, s, args.num_classes)
         zy = mx.sym.pick(fc7, gt_label, axis=1)  # ground-truth logit, s*cos(t)
         cos_t = zy / s
         cos_m = math.cos(m)
         sin_m = math.sin(m)
         mm = math.sin(math.pi - m) * m
         # cos(pi - m) == -cos(m), which is negative because 0 <= m < pi/2.
         threshold = math.cos(math.pi - m)
         if args.easy_margin:
             # Margin is applied only where cos(t) > 0.
             cond = mx.symbol.Activation(data=cos_t, act_type='relu')
         else:
             # Margin is applied only where cos(t) > cos(pi - m), i.e. t + m < pi.
             cond = mx.symbol.Activation(data=cos_t - threshold, act_type='relu')

         # cos(t + m) = cos(t)cos(m) - sin(t)sin(m), with sin(t) = sqrt(1 - cos(t)^2).
         # NOTE(review): if cos_t rounds slightly above 1, the sqrt argument is
         # negative and yields NaN; consider clipping cos_t.
         sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
         new_zy = (cos_t * cos_m - sin_t * sin_m) * s

         if args.easy_margin:
             zy_keep = zy
         else:
             # Where t + m >= pi, subtract a fixed penalty s*m*sin(pi - m) instead.
             zy_keep = zy - s * mm
         new_zy = mx.sym.where(cond, new_zy, zy_keep)
         fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, args.num_classes)

     elif args.loss_type == 5:  # combined margin
         s = args.margin_s
         a = args.margin_a
         m = args.margin_m
         b = args.margin_b
         assert s > 0.0
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = _scaled_cosine_fc7(embedding, _weight, s, args.num_classes)
         if a != 1.0 or m != 0.0 or b != 0.0:
             if a == 1.0 and m == 0.0:
                 # Only the cosine term is active: s*cos(t) -> s*(cos(t) - b).
                 gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes,
                                             on_value=s * b, off_value=0.0)
                 fc7 = fc7 - gt_one_hot
             else:
                 zy = mx.sym.pick(fc7, gt_label, axis=1)
                 t = mx.sym.arccos(zy / s)
                 if a != 1.0:
                     t = t * a
                 if m > 0.0:
                     t = t + m
                 body = mx.sym.cos(t)
                 if b > 0.0:
                     body = body - b
                 fc7 = _replace_target_logit(fc7, zy, body * s, gt_label, args.num_classes)

     elif args.loss_type == 6:  # angular margin + intra-class loss
         s = args.margin_s
         m = args.margin_m
         assert s > 0.0
         assert args.margin_b > 0.0
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = _scaled_cosine_fc7(embedding, _weight, s, args.num_classes)
         zy = mx.sym.pick(fc7, gt_label, axis=1)
         t = mx.sym.arccos(zy / s)
         # Intra-class loss: mean angle to the ground-truth class divided by pi.
         # Its weight is margin_b, applied through grad_scale.
         intra_loss = mx.sym.MakeLoss(mx.sym.mean(t / np.pi), name='intra_loss',
                                      grad_scale=args.margin_b)
         if m > 0.0:
             new_zy = mx.sym.cos(t + m) * s
             fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, args.num_classes)

     elif args.loss_type == 7:  # angular margin + inter-class loss
         s = args.margin_s
         m = args.margin_m
         assert s > 0.0
         assert args.margin_b > 0.0
         assert args.margin_a > 0.0  # validated, but margin_a is not otherwise used here
         _weight = mx.symbol.L2Normalization(_weight, mode='instance')
         fc7 = _scaled_cosine_fc7(embedding, _weight, s, args.num_classes)
         zy = mx.sym.pick(fc7, gt_label, axis=1)
         t = mx.sym.arccos(zy / s)
         # Inter-class loss: take the normalized weight row of each sample's label,
         # then compute its dot product with every normalized class weight (the
         # sample's own class included). The loss is the mean of these values.
         # Its weight is margin_b, applied through grad_scale.
         counter_weight = mx.sym.take(_weight, gt_label, axis=0)
         counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
         inter_loss = mx.sym.MakeLoss(mx.sym.mean(counter_cos), name='inter_loss',
                                      grad_scale=args.margin_b)
         if m > 0.0:
             new_zy = mx.sym.cos(t + m) * s
             fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, args.num_classes)

     else:
         # Without this, fc7 would be unbound and fail later with an opaque error.
         raise ValueError('unsupported loss_type: {}'.format(args.loss_type))

     out_list = [mx.symbol.BlockGrad(embedding)]
     softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax',
                                       normalization='valid')
     out_list.append(softmax)
     if intra_loss is not None:  # only set for loss_type 6
         out_list.append(intra_loss)
     if inter_loss is not None:  # only set for loss_type 7
         out_list.append(inter_loss)
     if args.ce_loss:
         # Cross-entropy: sum over the batch of -log(softmax(fc7)[gt]), divided by
         # per_batch_size. It is wrapped in BlockGrad, so it contributes no
         # gradient (presumably it is only reported for monitoring).
         body = mx.symbol.log(mx.symbol.SoftmaxActivation(data=fc7))
         _label = mx.sym.one_hot(gt_label, depth=args.num_classes,
                                 on_value=-1.0, off_value=0.0)
         ce_loss = mx.symbol.sum(body * _label) / args.per_batch_size
         out_list.append(mx.symbol.BlockGrad(ce_loss))
     out = mx.symbol.Group(out_list)

     return (out, arg_params, aux_params)


 def _build_embedding(args):
     """Return the backbone embedding symbol selected by ``args.network[0]``.

     'd' densenet, 'm' mobilenet (v1 when num_layers == 1, otherwise v2),
     'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet, 's' spherenet,
     'y' mobilefacenet. Any other prefix falls back to resnet.
     """
     net = args.network[0]
     if net == 'd':
         return fdensenet.get_symbol(args.emb_size, args.num_layers,
                                     version_se=args.version_se, version_input=args.version_input,
                                     version_output=args.version_output,
                                     version_unit=args.version_unit)
     if net == 'm':
         print('init mobilenet', args.num_layers)
         if args.num_layers == 1:
             return fmobilenet.get_symbol(args.emb_size,
                                          version_input=args.version_input,
                                          version_output=args.version_output,
                                          version_multiplier=args.version_multiplier)
         return fmobilenetv2.get_symbol(args.emb_size)
     if net == 'i':
         print('init inception-resnet-v2', args.num_layers)
         return finception_resnet_v2.get_symbol(args.emb_size,
                                                version_se=args.version_se,
                                                version_input=args.version_input,
                                                version_output=args.version_output,
                                                version_unit=args.version_unit)
     if net == 'x':
         print('init xception', args.num_layers)
         return fxception.get_symbol(args.emb_size,
                                     version_se=args.version_se, version_input=args.version_input,
                                     version_output=args.version_output,
                                     version_unit=args.version_unit)
     if net == 'p':
         print('init dpn', args.num_layers)
         return fdpn.get_symbol(args.emb_size, args.num_layers,
                                version_se=args.version_se, version_input=args.version_input,
                                version_output=args.version_output,
                                version_unit=args.version_unit)
     if net == 'n':
         print('init nasnet', args.num_layers)
         return fnasnet.get_symbol(args.emb_size)
     if net == 's':
         print('init spherenet', args.num_layers)
         return spherenet.get_symbol(args.emb_size, args.num_layers)
     if net == 'y':
         print('init mobilefacenet', args.num_layers)
         return fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom,
                                          version_output=args.version_output)
     print('init resnet', args.num_layers)
     return fresnet.get_symbol(args.emb_size, args.num_layers,
                               version_se=args.version_se, version_input=args.version_input,
                               version_output=args.version_output,
                               version_unit=args.version_unit,
                               version_act=args.version_act)


 def _scaled_cosine_fc7(embedding, norm_weight, s, num_classes):
     """Return the 'fc7' logits s*cos(t) for an L2-normalized class-weight matrix.

     The embedding is L2-normalized per instance (op name 'fc1n') and scaled by s,
     then multiplied by ``norm_weight`` with no bias. ``norm_weight`` is expected
     to be already L2-normalized by the caller.
     """
     nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
     return mx.sym.FullyConnected(data=nembedding, weight=norm_weight, no_bias=True,
                                  num_hidden=num_classes, name='fc7')


 def _replace_target_logit(fc7, zy, new_zy, gt_label, num_classes):
     """Return fc7 with each row's ground-truth logit zy replaced by new_zy.

     Adds (new_zy - zy) only in the one-hot ground-truth column and leaves every
     other column unchanged.
     """
     diff = mx.sym.expand_dims(new_zy - zy, 1)
     gt_one_hot = mx.sym.one_hot(gt_label, depth=num_classes, on_value=1.0, off_value=0.0)
     return fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值