insightface tripletloss源码阅读

                                               一:insighteface tripletloss实现中的dataiter部分解读

前言:

mxnet的dataiter一般为继承io.DataIter类,实现其中主要的几个函数。

需要实现的主要函数为:__init__(),reset(),next()。以及几个在fit函数中需要用到的几个属性:provide_label,provide_data等。

本文主要阅读了insightface代码中实现的triplet_image_iter.py。主要学习和解读其源代码。

代码解读:

1.首先需要实现的dataiter的reset()。源代码的reset()调用的主要函数为pick_triplets,select_triplets,pairwise_dists。

  • 选择三元组 pick_triplets

 函数:根据facenet的选取三元组的规则,尽量选择离anchor近的负样本。

 返回:返回所有的满足要求的三元组。

 python知识点:np.logical_and()  np.where()

具体注释:

## 根据facenet的选取要求选取三元组,返回三元组列表
    def pick_triplets(self, embeddings, nrof_images_per_class):
      emb_start_idx = 0
      triplets = []
      people_per_batch = len(nrof_images_per_class)#类别数
      #self.time_reset()
      pdists = self.pairwise_dists(embeddings)#由提取的特征列表计算样本间距离
      #self.times[3] += self.time_elapsed()

      for i in xrange(people_per_batch):
          nrof_images = int(nrof_images_per_class[i])#每一类的图片数
          for j in xrange(1,nrof_images):
              #self.time_reset()
              a_idx = emb_start_idx + j - 1
              #neg_dists_sqr = np.sum(np.square(embeddings[a_idx] - embeddings), 1)
              neg_dists_sqr = pdists[a_idx]
              #self.times[3] += self.time_elapsed()

              for pair in xrange(j, nrof_images): # For every possible positive pair.
                  p_idx = emb_start_idx + pair
                  #self.time_reset()
                  pos_dist_sqr = np.sum(np.square(embeddings[a_idx]-embeddings[p_idx]))#anchor 和 positive 之间的距离
                  #self.times[4] += self.time_elapsed()
                  #self.time_reset()
                  neg_dists_sqr[emb_start_idx:emb_start_idx+nrof_images] = np.NaN#将距离列表中正样本对距离置为无穷大,方便之后的选取负样本
                  if self.triplet_max_ap>0.0:
                    if pos_dist_sqr>self.triplet_max_ap:
                      continue
				  #np.where:返回满足条天的数组坐标,多维数组的时候返回多个列表
				  #np.logical_and:逻辑与,同时满足facenet选择条件(负样本和anchor的距离与正样本对距离之间差<alpha,正样本对距离小于负样本距离)
                  all_neg = np.where(np.logical_and(neg_dists_sqr-pos_dist_sqr<self.triplet_alpha, pos_dist_sqr<neg_dists_sqr))[0]  # FaceNet selection
                  #self.times[5] += self.time_elapsed()
                  #self.time_reset()
                  #all_neg = np.where(neg_dists_sqr-pos_dist_sqr<alpha)[0] # VGG Face selecction
                  nrof_random_negs = all_neg.shape[0]
                  if nrof_random_negs>0:#随机选取一个满足条件的负样本
                      rnd_idx = np.random.randint(nrof_random_negs)
                      n_idx = all_neg[rnd_idx]
                      triplets.append( (a_idx, p_idx, n_idx) )
          emb_start_idx += nrof_images
	  #打乱顺序
      np.random.shuffle(triplets)
      return triplets
  • select_triplets :获取anchor ,p,n的主函数

    函数调用pick_triplets。

    输出:self.seq,其中内容是anchor_batch,p_batch,n_batch。

    主要步骤:根据triplet_seq系列初始化数组,将数组送入模型进行前向计算,根据计算结果重新选择三元组,最终将三元组的结果保存。

   

    def select_triplets(self):
      self.seq = []
      while len(self.seq)<self.seq_min_size:
        self.time_reset()
        embeddings = None
        bag_size = self.triplet_bag_size
        batch_size = self.batch_size
        #data = np.zeros( (bag_size,)+self.data_shape )
        #label = np.zeros( (bag_size,) )
        tag = []
        #idx = np.zeros( (bag_size,) )
        print('eval %d images..'%bag_size, self.triplet_cur)
        print('triplet time stat', self.times)
        if self.triplet_cur+bag_size>len(self.triplet_seq):
          self.triplet_reset()
          #bag_size = min(bag_size, len(self.triplet_seq))
          print('eval %d images..'%bag_size, self.triplet_cur)
        self.times[0] += self.time_elapsed()
        self.time_reset()
        #print(data.shape)
        data = nd.zeros( self.provide_data[0][1] )
        label = None
        if self.provide_label is not None:
          label = nd.zeros( self.provide_label[0][1] )
        ba = 0
		#从0-bag_size
        while True:
          bb = min(ba+batch_size, bag_size)
          if ba>=bb:
            break
          _count = bb-ba
          #data = nd.zeros( (_count,)+self.data_shape )
          #_batch = self.data_iter.next()
          #_data = _batch.data[0].asnumpy()
          #print(_data.shape)
          #_label = _batch.label[0].asnumpy()
          #data[ba:bb,:,:,:] = _data
          #label[ba:bb] = _label
          for i in xrange(ba, bb):
            #print(ba, bb, self.triplet_cur, i, len(self.triplet_seq))
            _idx = self.triplet_seq[i+self.triplet_cur]#triplet_reset中初始化
            s = self.imgrec.read_idx(_idx)
            header, img = recordio.unpack(s)
            img = self.imdecode(img)
            data[i-ba][:] = self.postprocess_data(img)
            _label = header.label
            if not isinstance(_label, numbers.Number):
              _label = _label[0]
            if label is not None:
              label[i-ba][:] = _label
            tag.append( ( int(_label), _idx) )
            #idx[i] = _idx

          db = mx.io.DataBatch(data=(data,))
		  ##前向计算当前batch
          self.mx_model.forward(db, is_train=False)
          net_out = self.mx_model.get_outputs()#获取前向的结果
          #print('eval for selecting triplets',ba,bb)
          #print(net_out)
          #print(len(net_out))
          #print(net_out[0].asnumpy())
          net_out = net_out[0].asnumpy()
          #print(net_out)
          #print('net_out', net_out.shape)
          if embeddings is None:
            embeddings = np.zeros( (bag_size, net_out.shape[1]))
          embeddings[ba:bb,:] = net_out
          ba = bb
        assert len(tag)==bag_size
        self.triplet_cur+=bag_size
        embeddings = sklearn.preprocessing.normalize(embeddings)
        self.times[1] += self.time_elapsed()
        self.time_reset()
		## 获取类别数和每个类别样本数
        nrof_images_per_class = [1]
        for i in xrange(1, bag_size):
          if tag[i][0]==tag[i-1][0]:#lable
            nrof_images_per_class[-1]+=1
          else:
            nrof_images_per_class.append(1)
        ## 选择三元组
        triplets = self.pick_triplets(embeddings, nrof_images_per_class) # shape=(T,3)
        print('found triplets', len(triplets))
        ba = 0
        while True:
          bb = ba+self.per_batch_size//3
          if bb>len(triplets):
            break
          _triplets = triplets[ba:bb]
          for i in xrange(3):
            for triplet in _triplets:
              _pos = triplet[i]
              _idx = tag[_pos][1]#idx
              self.seq.append(_idx)# a_batch p_batch n_batch 
          ba = bb
        self.times[2] += self.time_elapsed()

 

2.next()

 函数返回:databatch,datalabel。

def next(self):
        if not self.is_init:
          self.reset()
          self.is_init = True
        """Returns the next batch of data."""
        #print('in next', self.cur, self.labelcur)
        self.nbatch+=1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))#batch_data
        if self.provide_label is not None:
          batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)#编码s为一个ndarray
                if self.rand_mirror:#随机镜像
                  _rd = random.randint(0,1)
                  if _rd==1:
                    _data = mx.ndarray.flip(data=_data, axis=1)
                if self.cutoff>0:#截断
                  centerh = random.randint(0, _data.shape[0]-1)
                  centerw = random.randint(0, _data.shape[1]-1)
                  half = self.cutoff//2
                  starth = max(0, centerh-half)
                  endh = min(_data.shape[0], centerh+half)
                  startw = max(0, centerw-half)
                  endw = min(_data.shape[1], centerw+half)
                  _data = _data.astype('float32')
                  #print(starth, endh, startw, endw, _data.shape)
                  _data[starth:endh, startw:endw, :] = 127.5
                #_npdata = _data.asnumpy()
                #if landmark is not None:
                #  _npdata = face_preprocess.preprocess(_npdata, bbox = bbox, landmark=landmark, image_size=self.image_size)
                #if self.rand_mirror:
                #  _npdata = self.mirror_aug(_npdata)
                #if self.mean is not None:
                #  _npdata = _npdata.astype(np.float32)
                #  _npdata -= self.mean
                #  _npdata *= 0.0078125
                #nimg = np.zeros(_npdata.shape, dtype=np.float32)
                #nimg[self.patch[1]:self.patch[3],self.patch[0]:self.patch[2],:] = _npdata[self.patch[1]:self.patch[3], self.patch[0]:self.patch[2], :]
                #_data = mx.nd.array(nimg)
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                #print('aa',data[0].shape)
                #data = self.augmentation_transform(data)
                #print('bb',data[0].shape)
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    #print(datum.shape)
                    batch_data[i][:] = self.postprocess_data(datum)
                    if self.provide_label is not None:
                      batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i<batch_size:
                raise StopIteration

        #print('next end', batch_size, i)
        _label = None
        if self.provide_label is not None:
          _label = [batch_label]
        return io.DataBatch([batch_data], _label, batch_size - i)

                                            二:insighteface tripletloss实现中的loss部分

源码:train_triplet.py

nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
  #在第一个维度上切分三分,分别是anchor,positive,negative
  anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
  positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
  negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
  #tripletloss 实现
  ap = anchor - positive
  an = anchor - negative
  ap = ap*ap
  an = an*an
  ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
  an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
  triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
  triplet_loss = mx.symbol.mean(triplet_loss)
  #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
  triplet_loss = mx.symbol.MakeLoss(triplet_loss)
  #预测值和loss值的合并,将预测值堵塞反向传播
  out_list = [mx.symbol.BlockGrad(embedding)]
  out_list.append(mx.sym.BlockGrad(gt_label))
  out_list.append(triplet_loss)
  out = mx.symbol.Group(out_list)

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值