【tensorflow】MTCNN网络Pnet数据转化为tfrecord文件

将人脸与关键点数据合并后文件转化为tfrecord文件。

# coding:utf-8
import os
import random
import sys
import tensorflow as tf
import cv2
def _convert_to_example_simple(image_example, image_buffer):
    """
    covert to tfrecord file
    :param image_example: dict, an image example
    :param image_buffer: string, JPEG encoding of RGB image
    :param colorspace:
    :param channels:
    :param image_format:
    :return:
    Example proto
    """
    # filename = str(image_example['filename'])
    # class label for the whole image
    class_label = image_example['label']  # 传入label值(1、0、-1、-2)
    bbox = image_example['bbox']  # 传入bbox
    roi = [bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']]  # 传入坐标的4个值
    landmark = [bbox['xlefteye'], bbox['ylefteye'], bbox['xrighteye'], bbox['yrighteye'], bbox['xnose'], bbox['ynose'],
                bbox['xleftmouth'], bbox['yleftmouth'], bbox['xrightmouth'], bbox['yrightmouth']]
    # 传入landmark的10个值
    # example
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_buffer),  # 图片矩阵的字符串形式
        'image/label': _int64_feature(class_label),  # label
        'image/roi': _float_feature(roi),  # 人脸框4个坐标
        'image/landmark': _float_feature(landmark)  # landmark的10个坐标
    }))
    return example

def _float_feature(value):
    """Wrapper for insert float features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))



def _int64_feature(value):
    """Wrapper for insert int64 feature into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for insert bytes features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _process_image_withoutcoder(imgfile):
    # transform data into string format
    image = cv2.imread(imgfile)
    image_data = image.tostring()   		  #构建一个包含ndarray的原始字节数据的字节字符串
    assert len(image.shape) == 3         	  #判断图片的格式是否为(高、宽、通道数)
    height = image.shape[0]				  #传入图片的高
    width = image.shape[1]				  #传入图片的宽
    assert image.shape[2] == 3			  #判断图片是否为RGB三色通道
    # return string data and initial height and width of the image
    print(image_data)
    return image_data, height, width        #返回字符串形式的数据和图片原始高宽

def _bytes_feature(value):
    """Wrapper for insert bytes features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _add_to_tfrecord(filename, image_example, tfrecord_writer):
    """Loads data from image and annotations files and add them to a TFRecord.

    Args:
      filename: Dataset directory;
      name: Image name to add to the TFRecord;
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    # 从图片和注释文件里加载数据并将其添加到TFRecord里
    # filename: 数据目录
    # image_example: 数据,为字典形式,里面有三个key
    # tfrecord_writer:with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer
    # print('---', filename)

    # imaga_data:array to string
    # height:original image's height
    # width:original image's width
    # image_example dict contains image's info
    # image_data:转化成了字符串的图片
    # height:图片原始高度
    # width:图片原始宽度
    # image_example字典包含图片的信息
    print(filename)
    image_data, height, width = _process_image_withoutcoder(filename)
    example = _convert_to_example_simple(image_example, image_data)
    tfrecord_writer.write(example.SerializeToString())
    # TFRecord制作结束


def _get_output_filename(output_dir, name, net):  # 获得一个输出的文件名
    # st = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    # return '%s/%s_%s_%s.tfrecord' % (output_dir, name, net, st)
    return '%s/train_PNet_landmark.tfrecord' % (
        output_dir)  # 返回的是'../../DATA/imglists/PNet/train_PNet_landmark.tfrecord'


def run(dataset_dir, net, output_dir, name='MTCNN', shuffling=False):
    """Runs the conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      output_dir: Output directory.
    """

    # tfrecord name
    tf_filename = _get_output_filename(output_dir, name, net)  # '../../DATA/imglists/PNet/train_PNet_landmark.tfrecord'
    if tf.gfile.Exists(tf_filename):  # 判断是否存在同名文件
        print('Dataset files already exist. Exiting without re-creating them.')
        return
    # GET Dataset, and shuffling.
    dataset = get_dataset(dataset_dir, net=net)  # 列表dataset
    # filenames = dataset['filename']
    if shuffling:
        tf_filename = tf_filename + '_shuffle'  # shuffling=True
        # random.seed(12345454)
        random.shuffle(dataset)  # 打乱dataset的顺序
    # Process dataset files.
    # write the data to tfrecord
    print('lala')  # 打印'lala'
    with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
        for i, image_example in enumerate(dataset):  # 读取dataset的索引和内容
            if (i + 1) % 100 == 0:
                sys.stdout.write('\r>> %d/%d images has been converted' % (i + 1, len(dataset)))
                # sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(dataset)))
            sys.stdout.flush()  # 刷新输出
            filename = image_example['filename']  # 赋值
            _add_to_tfrecord(filename, image_example, tfrecord_writer)
            # Finally, write the labels file:
    # labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
    # dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
    print('\nFinished converting the MTCNN dataset!')


def get_dataset(dir, net='PNet'):
    # get file name , label and anotation
    # item = 'imglists/PNet/train_%s_raw.txt' % net
    # 获取文件名字,标签和注释
    # item =  'imglists/PNet/train_PNet_landmark.txt'
    item = 'imglists/PNet/train_%s_landmark.txt' % net

    dataset_dir = os.path.join(dir, item)  # dataset_dir = '../../DATA/imglists/PNet/train_PNet_landmark.txt'
    # print(dataset_dir)
    imagelist = open(dataset_dir, 'r')  # 以只读的形式打开train_PNet_landmark.txt,并传入imagelist里面

    dataset = []  # 新建列表
    for line in imagelist.readlines():  # 读取imagelist里面的内容
        info = line.strip().split(' ')  # 去除每一行首尾的空格并且以空格为分隔符读取内容到info里面
        data_example = dict()  # 新建字典
        bbox = dict()
        data_example['filename'] = info[0]  # filename=info[0]
        # print(data_example['filename'])
        data_example['label'] = int(info[1])  # label=info[1],info[1]的值有四种可能,1,0,-1,-2;分别对应着正、负、无关、关键点样本。
        bbox['xmin'] = 0  # 初始化bounding box的值
        bbox['ymin'] = 0
        bbox['xmax'] = 0
        bbox['ymax'] = 0
        bbox['xlefteye'] = 0  # 初始化人脸坐标的值
        bbox['ylefteye'] = 0
        bbox['xrighteye'] = 0
        bbox['yrighteye'] = 0
        bbox['xnose'] = 0
        bbox['ynose'] = 0
        bbox['xleftmouth'] = 0
        bbox['yleftmouth'] = 0
        bbox['xrightmouth'] = 0
        bbox['yrightmouth'] = 0
        if len(info) == 6:  # info的长度等于6时,表示此时的info是正样本或者无关样本,详情请看学习记录(一)的文末
            bbox['xmin'] = float(info[2])
            bbox['ymin'] = float(info[3])
            bbox['xmax'] = float(info[4])
            bbox['ymax'] = float(info[5])
        if len(info) == 12:  # info长度等于12时,表示此时的info是landmark样本
            bbox['xlefteye'] = float(info[2])
            bbox['ylefteye'] = float(info[3])
            bbox['xrighteye'] = float(info[4])
            bbox['yrighteye'] = float(info[5])
            bbox['xnose'] = float(info[6])
            bbox['ynose'] = float(info[7])
            bbox['xleftmouth'] = float(info[8])
            bbox['yleftmouth'] = float(info[9])
            bbox['xrightmouth'] = float(info[10])
            bbox['yrightmouth'] = float(info[11])

        data_example['bbox'] = bbox  # 将bbox值传入字典
        dataset.append(data_example)  # 将字典data_example传入列表dataset

    return dataset  # 返回dataset,datase是个列表,但是里面的每个元素都是一个字典,每个字典有3个key,分别是filename、label和bbox。


if __name__ == '__main__':
    dir = 'E:/MTCNN'
    net = 'PNet'
    output_directory = 'E:/MTCNN/imglists/PNet'
    run(dir, net, output_directory, shuffling=True)

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

胖子工作室

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值