insightface Face Recognition Code Notes (4) (Training Code)

1. Preface

This post walks through insightface's training code and, in the process, introduces the overall structure of an MXNet training script: two birds with one stone.
The code path is ~/src/train_softmax.py.
Index: insightface Face Recognition Code Notes (Overview) (based on MXNet)

2. Main Content

First, the training log is set up.
Next, a Module object, model, is created in preparation for the later call to fit(). Depending on how this model is obtained, training generally falls into one of the following cases:
① training a network you define yourself, from scratch;
② transfer learning, i.e. fine-tuning a model someone else has trained;
③ resuming training of your own model from a checkpoint.
Finally, calling model.fit() enters MXNet's training code proper, so everything before that call just prepares the function's arguments, mainly the following:

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )

train_data: the training data

    ...
    path_imgrec = os.path.join(data_dir, "train.rec")
    data_shape = (args.image_channel,image_size[0],image_size[1])
    ...
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
        color_jittering      = args.color,
        images_filter        = args.images_filter,
    )
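
FaceImageIter is insightface's own iterator (a subclass of mx.io.DataIter) that handles rec-file reading plus augmentations such as mirroring, cutoff occlusion, and color jitter. As a rough sketch of what fit() actually requires from a data iterator, here is a toy iterator that yields random batches; the class name, shapes, and label range are all made up for illustration:

import mxnet as mx
import numpy as np

class RandomFaceIter(mx.io.DataIter):
    """Toy stand-in for FaceImageIter: yields random image batches."""
    def __init__(self, batch_size, data_shape, num_batches=10):
        super(RandomFaceIter, self).__init__()
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.num_batches = num_batches
        self.cur = 0
        # fit() reads these descriptions to bind the module's input shapes
        self.provide_data = [('data', (batch_size,) + data_shape)]
        self.provide_label = [('softmax_label', (batch_size,))]

    def reset(self):
        self.cur = 0

    def next(self):
        if self.cur >= self.num_batches:
            raise StopIteration
        self.cur += 1
        data = mx.nd.array(np.random.uniform(0, 255, (self.batch_size,) + self.data_shape))
        label = mx.nd.array(np.random.randint(0, 10, size=(self.batch_size,)))
        return mx.io.DataBatch(data=[data], label=[label])

# e.g. RandomFaceIter(32, (3, 112, 112)) could be passed to fit() in place of train_dataiter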

eval_data: the evaluation data, set to None here

begin_epoch, num_epoch: the starting epoch and the number of epochs

begin_epoch = 0
end_epoch = args.end_epoch

eval_metric: the evaluation metrics

metric1 = AccMetric()
eval_metrics = [mx.metric.create(metric1)]

if args.ce_loss:
  metric2 = LossValueMetric()
  eval_metrics.append( mx.metric.create(metric2) )
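
AccMetric and LossValueMetric are custom subclasses of mx.metric.EvalMetric (their bodies are elided in the listing below; see the earlier posts for the real implementations). A minimal sketch of such a subclass, with hypothetical names:

import mxnet as mx

class SimpleAccMetric(mx.metric.EvalMetric):
    """Toy accuracy metric in the spirit of AccMetric."""
    def __init__(self):
        super(SimpleAccMetric, self).__init__('simple_acc')

    def update(self, labels, preds):
        # fit() calls update() on every batch; get() later returns (name, sum/num)
        pred = preds[0].asnumpy().argmax(axis=1)
        label = labels[0].asnumpy().astype('int32').flatten()
        self.sum_metric += int((pred == label).sum())
        self.num_inst += label.shape[0]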

optimizer, optimizer_params: optimizer settings

opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
# An alternative:
# lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=step_bs,factor=args.factor)
# optimizer_params = {'learning_rate': 0.001,
#                        'momentum': 0.9,
#                        'wd': 0.0001,
#                        'lr_scheduler': lr_scheduler,
#                        'rescale_grad': 1.0/len(ctx) if len(ctx)>0 else 1.0}
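
One point worth noting: fit() accepts either an Optimizer instance (as above) or an optimizer name string, and optimizer_params only takes effect in the string form, which is presumably why that argument is commented out in the fit() call here. A minimal sketch of the string form, with illustrative values:

import mxnet as mx

# Hypothetical hyperparameters; with a scheduler, the lr decay would not
# need to be done by hand inside _batch_callback.
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[100000, 140000, 160000], factor=0.1)
optimizer_params = {'learning_rate': 0.1,
                    'momentum': 0.9,
                    'wd': 0.0005,
                    'lr_scheduler': lr_scheduler,
                    'rescale_grad': 1.0/4}   # 1/ctx_num, assuming 4 GPUs
# model.fit(..., optimizer='sgd', optimizer_params=optimizer_params)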

initializer: parameter initialization

 initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
 # ① rnd_type: 'uniform' (default) or 'gaussian'
 # ② factor_type: 'avg' (default), 'in', or 'out'
 # ③ magnitude: scale factor for the random numbers
 # Not expanded here for space reasons; look these up if needed
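
For a more concrete feel, an initializer instance can be called directly to fill an NDArray (a minimal sketch; the parameter name and shape are made up):

import mxnet as mx

init = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
weight = mx.nd.empty((64, 128))
# Initializer instances are callable: (InitDesc, NDArray) -> fills the array in place
init(mx.init.InitDesc('fc1_weight'), weight)
print(weight.asnumpy().std())   # roughly sqrt(2/128) for factor_type='in'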

arg_params, aux_params: the network's (auxiliary) parameters

if len(args.pretrained)==0:
  arg_params = None
  aux_params = None
  sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
  if args.network[0]=='s':
    data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
    spherenet.init_weights(sym, data_shape_dict, args.num_layers)
else:
  vec = args.pretrained.split(',')
  print('loading', vec)
  _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
  sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
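
For the pretrained branch, mx.model.load_checkpoint(prefix, epoch) reads prefix-symbol.json and prefix-%04d.params from disk, so args.pretrained is expected in the form 'prefix,epoch'. A minimal sketch with a hypothetical prefix:

import mxnet as mx

# Assumes mymodel-symbol.json and mymodel-0001.params exist on disk
sym0, arg_params, aux_params = mx.model.load_checkpoint('mymodel', 1)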

epoch_end_callback: the model-saving callback

# Model-save path setup; normally done with:
# epoch_callback = mx.callback.do_checkpoint(prefix, period)
# prefix: checkpoint name prefix     period: save every `period` epochs
 epoch_cb = None
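
If saving at every epoch were wanted instead of the batch-level saving done in _batch_callback, a single line would do it (prefix hypothetical):

import mxnet as mx

epoch_cb = mx.callback.do_checkpoint('./models/mymodel', period=1)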

batch_end_callback: the logging interval in batches

_cb = mx.callback.Speedometer(args.batch_size, som)
_batch_callback():...
# Function signature: mx.callback.Speedometer(batch_size, frequent)
# frequent: log progress once every `frequent` training batches

With these arguments in mind, reading the full training script becomes much simpler. The detailed notes are in the code below.

train_softmax.py

from __future__ import absolute_import
...

import os
import sys
import math
...
...

# Set up the training log
logger = logging.getLogger()
logger.setLevel(logging.INFO)

args = None
# Evaluation metrics; see the earlier posts for details
class AccMetric(mx.metric.EvalMetric):...
  
class LossValueMetric(mx.metric.EvalMetric):...

# Parse all the required arguments
# note: the training-data path is set here
def parse_args():...
  
# Explained separately; see the link at the end
def get_symbol(args, arg_params, aux_params):...
 
 
def train_net(args):

# Choose GPU/CPU contexts for training
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    
# Set up the checkpoint save path
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
# Preset a series of parameters
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    # Read the dataset's property file
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    #image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    # Get the training-data path
    path_imgrec = os.path.join(data_dir, "train.rec")    
    # Set parameters according to the loss type
    if args.loss_type==1 and args.num_classes>20000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None
# Check whether a pretrained model exists
    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    # No pretrained model: train from scratch
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
      if args.network[0]=='s':
        data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    # A pretrained model exists, or we are resuming from a checkpoint
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)


# Initialize the model

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context       = ctx,
        symbol        = sym,
    )
    val_dataiter = None
# Build the train_data argument
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
        color_jittering      = args.color,
        images_filter        = args.images_filter,
    )
# Build the eval_metric argument
    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]

    if args.ce_loss:
      metric2 = LossValueMetric()
      eval_metrics.append( mx.metric.create(metric2) )


# Choose the initializer according to the network type

    if args.network[0]=='r' or args.network[0]=='y':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
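    # rescale_grad = 1/ctx_num averages the gradients that kvstore sums across devices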
    _rescale = 1.0/args.ctx_num
# Optimizer setup
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
# Set the logging interval in batches
    som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

# Load the verification (test) sets
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


# Evaluate on the verification sets
    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

# Set up lr_steps
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
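    # e.g. batch_size=256 gives p=512/256=2.0, so [40000,60000,80000] becomes [80000,120000,160000]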
# Model saving and the scheduling of lr and other parameters
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]

      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:        #args.beta_freeze = 5000
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:      #args.verbose = 2000
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:    
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:

            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score

            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        # Checkpoint policy (args.ckpt)
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta                 #args.beta = 1000
      else:                               #mbatch>args.beta_freeze
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
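      # i.e. beta anneals as beta * (1 + gamma*move)^(-power), floored at beta_min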
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

        
 # Model-save path setup (per-epoch saving is disabled; checkpoints are written in _batch_callback)
    epoch_cb = None
    # mx.io.PrefetchingIter() seems to be an interface that combines several data iterators (it also prefetches batches in a separate thread)
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
 # Training entry point
    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )

def main():
    global args
    args = parse_args()
    train_net(args)

if __name__ == '__main__':
    main()

3. Conclusion

That wraps up the training part. For space reasons, the get_symbol() function used in the training code is covered at the link below:

get_symbol(): Notes on the Loss Functions in insightface
