insightface人脸识别代码记录(二)(数据处理)

一、前言

这部分主要围绕insightface目录下~/src/image_iter.py进行记录。其实,src目录下好多和前面目录重复的文件,好像是作者最开始是基于此目录进行训练的吧。
目录地址:insightface人脸识别代码记录(总)(基于MXNet)

二、主要内容

结合此脚本下的FaceImageIter类来进行记录MXNet中关于数据处理的一般形式,主要记录此类下面的__init__,next(),next_sample(),其他方法捎带简单解释。

可以看到此类继承的是mxnet.io.DataIter类,而此类正是MXNet中的构造数据迭代器的基础类。主要重写的是此类下的next方法。因为在MXNet中调用mxnet.io.DataIter接口,需要传送入数据(self.getdata())、标签(self.getlabel())、pad方式(self.getpad())和index信息(self.getindex()),而一般next方法就是将这几个数据封装到一起。

1.__init__:首先,读取.rec文件的路径,然后通过recordio.MXIndexedRecordIO接口来进行读取,然后读取idx为0的数据,这个idx对应.idx文件下的id为0的元素(这里见下面关于InsightFace Record格式),然后调用recordio脚本里的unpack方法,返回一个IRHeader,这个是关于图像记录的头文件,具体介绍见附录Fig1。然后根据header的flag的不同,进行不同的处理,获得图片的唯一标识idx,然后存放于self.imgidx。然后提下这个self.seq和shuffle,(因为后续会用到这个数据的含义,所以解释下)这个是区别MXNet数据读取方式的标志,详细见附录Fig2。剩下的就是一些数据增强的操作了。

2.next_sample():根据self.seq,来获得不同数据处理情况下的label和img。

3.next():获取next_sample()得到的label和img,并对其进行一系列数据增强操作,然后利用mxnet.io.DataBatch进行封装处理。

最后提下,imdecoderead_image,postprocess_data
imdecode是因为recordio脚本下unpack获得的是图像编码形式,需要用此函数下的mx.image.imdecode将其转化为ndarray格式,便于后续处理;
read_image就是处理只有.lst文件情况下的图像读取方式;
postprocess_data是对所获的图像做一个维度变化操作,原因是imdecode转化后的ndarray格式是(h,w,c)格式的,和cv2.imread()所获的格式一样,需要转化为(c,h,w)格式。

image_iter.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import logging
...
...


class FaceImageIter(io.DataIter):

    def __init__(self, batch_size, data_shape,
                 path_imgrec = None,
                 shuffle=False, aug_list=None, mean = None,
                 rand_mirror = False, cutoff = 0, color_jittering = 0,
                 images_filter = 0,
                 data_name='data', label_name='softmax_label', **kwargs):
        super(FaceImageIter, self).__init__()
        assert path_imgrec
        if path_imgrec:
            logging.info('loading recordio %s...',
                         path_imgrec)
            path_imgidx = path_imgrec[0:-4]+".idx"
            self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
            
            s = self.imgrec.read_idx(0)
            header, _ = recordio.unpack(s)
            
            if header.flag>0:
              print('header0 label', header.label)
              self.header0 = (int(header.label[0]), int(header.label[1]))
              #assert(header.flag==1)
              #self.imgidx = range(1, int(header.label[0]))
              self.imgidx = []
              self.id2range = {}
              self.seq_identity = range(int(header.label[0]), int(header.label[1]))
              for identity in self.seq_identity:
                s = self.imgrec.read_idx(identity)
                header, _ = recordio.unpack(s)
                a,b = int(header.label[0]), int(header.label[1])
                count = b-a
                if count<images_filter:
                  continue
                self.id2range[identity] = (a,b)
                self.imgidx += range(a, b)
              print('id2range', len(self.id2range))
            else:
              self.imgidx = list(self.imgrec.keys)
            if shuffle:
              self.seq = self.imgidx
              self.oseq = self.imgidx
              print(len(self.seq))
            else:
              self.seq = None

        self.mean = mean
        self.nd_mean = None
        if self.mean:
          self.mean = np.array(self.mean, dtype=np.float32).reshape(1,1,3)
          self.nd_mean = mx.nd.array(self.mean).reshape((1,1,3))

        self.check_data_shape(data_shape)
        self.provide_data = [(data_name, (batch_size,) + data_shape)]
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.shuffle = shuffle
        self.image_size = '%d,%d'%(data_shape[1],data_shape[2])
        self.rand_mirror = rand_mirror
        print('rand_mirror', rand_mirror)
        self.cutoff = cutoff
        self.color_jittering = color_jittering
        self.CJA = mx.image.ColorJitterAug(0.125, 0.125, 0.125)
        self.provide_label = [(label_name, (batch_size,))]
        #print(self.provide_label[0][1])
        self.cur = 0
        self.nbatch = 0
        self.is_init = False


    def reset(self):...

    def num_samples(self):
      return len(self.seq)

    def next_sample(self):
        """Helper function for reading in next sample."""
        #set total batch size, for example, 1800, and maximum size for each people, for example 45
        
        if self.seq is not None:
          while True:
            if self.cur >= len(self.seq):
                raise StopIteration
            idx = self.seq[self.cur]
            self.cur += 1
            #有.rec文件和.idx文件的情况
            if self.imgrec is not None:
              s = self.imgrec.read_idx(idx)
              header, img = recordio.unpack(s)
              label = header.label
              if not isinstance(label, numbers.Number):
                label = label[0]
              return label, img, None, None
            #只有.ist文件的情况
            else:
              label, fname, bbox, landmark = self.imglist[idx]
              return label, self.read_image(fname), bbox, landmark
   		#只有.rec文件的情况
        else:
            s = self.imgrec.read()
            if s is None:
                raise StopIteration
            header, img = recordio.unpack(s)
            return header.label, img, None, None

    def brightness_aug(self, src, x):...
      
    def contrast_aug(self, src, x):...
      
    def saturation_aug(self, src, x):...
      
    def color_aug(self, img, x):...
      
    def mirror_aug(self, img):...
    
    def compress_aug(self, img):...
      
    def next(self):
        if not self.is_init:
          self.reset()
          self.is_init = True
        """Returns the next batch of data."""
        #print('in next', self.cur, self.labelcur)
        self.nbatch+=1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
          batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                
                if _data.shape[0]!=self.data_shape[1]:...                  
                if self.rand_mirror:...                 
                if self.color_jittering>0:...               
                if self.nd_mean is not None:... 
                if self.cutoff>0:...
                  
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    #print(datum.shape)
                    batch_data[i][:] = self.postprocess_data(datum)
                    batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i<batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - i)

    def check_data_shape(self, data_shape):...
        
    def check_valid_image(self, data):...
        
    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.
        See mx.img.imdecode for more details."""
        img = mx.image.imdecode(s) #mx.ndarray
        return img

    def read_image(self, fname):
        """Reads an input image `fname` and returns the decoded raw bytes.

        Example usage:
        ----------
        >>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
        """
        with open(os.path.join(self.path_root, fname), 'rb') as fin:
            img = fin.read()
        return img

    def augmentation_transform(self, data):
        """Transforms input data with specified augmentation."""
        for aug in self.auglist:
            data = [ret for src in data for ret in aug(src)]
        return data

    def postprocess_data(self, datum):
        """Final postprocessing step before image is loaded into the batch."""
        return nd.transpose(datum, axes=(2, 0, 1))

class FaceImageIterList(io.DataIter):...

附录:

Fig1:
recordio脚本地址:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/recordio.py
以下截图均来自于recordio脚本。
在这里插入图片描述
从上图我们可以看到IRHeader有4个属性:flag,label,id,id2。
flag:分两种情况,flag为0和不为0。为0,代表只有单个数据,这时label为一个数字,即创建lst时所创建的数字,代表数据所属类别,此时id2也就为0;若flag不为0,则flag代表label的size,此时,label不再是一个数字,而是一个array,存储这此数据的长度(推测,如果有对这个含义清楚的,希望指正),此时id2不再为0,而是起到前面id标识的作用,因为此时数据大于1个,id自然无法唯一标识。
而flag大于1的情况下,label的值好像来自于对应header的所获图像编码。推测依据如下:
在这里插入图片描述
在这里插入图片描述
Fig2:
mxnet.image.ImageIter接口的记录。
mxnet.image.ImageIter接口继承自MXNet框架下的基础数据迭代器构造类mxnet.io.DataIter,该接口是python代码实现的图像数据迭代器,既可读取.rec文件,也可以以图像+.lst方式来读取数据。
mxnet.image.ImageIter类地址:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/image/image.py

class ImageIter(io.DataIter):
	def __init__(...):
		...
		#处理.rec文件格式
		if path_imgrec:
	            logging.info('%s: loading recordio %s...',
	                         class_name, path_imgrec)
	            #存在.idx文件和.rec文件的情况
	            if path_imgidx:
	                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
	                self.imgidx = list(self.imgrec.keys)
	            #只有.rec文件的情况
	            else:
	                self.imgrec = recordio.MXRecordIO(path_imgrec, 'r')
	                self.imgidx = None
       	else:
	            self.imgrec = None
	
	        array_fn = _mx_np.array if is_np_array() else nd.array
	    #.lst文件+原图像的情况
       	if path_imglist:
           logging.info('%s: loading image list %s...', class_name, path_imglist)
           with open(path_imglist) as fin:
               imglist = {}
               imgkeys = []
               for line in iter(fin.readline, ''):
                   line = line.strip().split('\t')
                   label = array_fn(line[1:-1], dtype=dtype)
                   key = int(line[0])
                   imglist[key] = (label, line[-1])
                   imgkeys.append(key)
               self.imglist = imglist
       	elif isinstance(imglist, list):
           logging.info('%s: loading image list...', class_name)
           result = {}
           imgkeys = []
           index = 1
           for img in imglist:
               key = str(index)
               index += 1
               if len(img) > 2:
                   label = array_fn(img[:-1], dtype=dtype)
               elif isinstance(img[0], numeric_types):
                   label = array_fn([img[0]], dtype=dtype)
               else:
                   label = array_fn(img[0], dtype=dtype)
               result[key] = (label, img[-1])
               imgkeys.append(str(key))
           self.imglist = result
       else:
           self.imglist = None
      
        ...
        ...
#根据imgkeys和self.imgidx可推测如下情况。

        #.lst文件+原图像的情况
        if self.imgrec is None:
            self.seq = imgkeys
        #存在.idx文件和.rec文件的情况
        elif shuffle or num_parts > 1 or path_imgidx:
            assert self.imgidx is not None
            self.seq = self.imgidx
        #只有.rec文件的情况
        else:
            self.seq = None

结尾

差不多到这里就结束了。在这里记录下insightface数据处理的方法,同时也学习了MXNet框架。
另外,记一下另外两个经常用到的数据处理的接口:
图像分类:mxnet.io.ImageRecordIter()
目标检测:mxnet.io.ImageDetRecordIter()

InsightFace Record格式:
key = 0 , value_header => [identities_key_start, identities_key_end]

key∈[1, identities_key_start), value_header => [identity_label],value_content => [face_image]

key∈[identities_key_start, identities_key_end), value_header => [identity_key_start, identity_key_end]

参考

MXNet源码解读:数据读取高级类(2)— mxnet.image.ImageIter
MXNet源码解读:数据读取基础类—mxnet.io.DataIter
InsightFace - 使用篇, 如何一键刷分LFW 99.80%, MegaFace 98%.

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值