TFRecords文件实现不定长图片和标签的存储和读取感悟(2)(更新版)

将不定长图片和标签生成TFRecords文件进行保存,前期是使用PIL模块进行图片的读取,详情见TFRecords文件实现不定长图片和标签的存储和读取感悟(1)(附完整代码),由于每次batch时要求图片的尺寸大小一致,所以就需要定义一个最大宽度(所有图片高度一定)max_width,需要对图片进行补零填充,此种方法怎么说呢,就是有点不讨喜,后来使用cv2模块进行图片的读取,从tfrecord二进制文件中读取图片时,使用下述函数,就能够恢复图片大小,不用设定图片的最大宽度值,在batch的时候,将dynamic_pad参数设定为True,可自动补零填充,对label亦如是。

image = tf.image.decode_jpeg(image_features['image'])
image.set_shape([32, None, 3])

完整代码如下: 

# coding=utf-8
import tensorflow as tf
import os
import cv2
import random
import json
import sys

NUM_EXAMPLES_PER_EPOCH = 20000
RATIO = 0.9

def int64_feature(value):
    """Wrap a single integer in a tf.train.Feature holding an int64 list of length 1."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def int64_list_feature(value):
    """Wrap an iterable of integers in a tf.train.Feature holding an int64 list."""
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)


def bytes_feature(value):
    """Wrap a single bytes object in a tf.train.Feature holding a bytes list of length 1."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)


def generation_TFRecord(data_dir, tfrecord_dir):
    '''
    Generate train/test TFRecord files from labelled JPEG images.

    Each image's label is encoded in its filename (the stem before the
    ".jpg" extension); every character of the stem is mapped to an
    integer id via ./map.json.  Images are resized so their height is
    32 pixels (width scaled proportionally), re-encoded as JPEG, and
    written as one tf.train.Example per image with features
    'label' (int64 list) and 'image' (JPEG bytes).

    The shuffled file list is split RATIO / (1 - RATIO) into
    train_dataset.tfrecord and test_dataset.tfrecord, capped at
    NUM_EXAMPLES_PER_EPOCH examples in total.

    @param data_dir: directory containing the *.jpg source images
    @param tfrecord_dir: directory the tfrecord files are written into
    @return: None
    '''
    vocabulary = json.load(open("./map.json", "r"))
    image_name_list = [f for f in os.listdir(data_dir) if f.endswith('.jpg')]
    random.shuffle(image_name_list)

    split = int(RATIO * NUM_EXAMPLES_PER_EPOCH)
    # Train split.
    _write_split(data_dir, vocabulary, image_name_list[:split],
                 os.path.join(tfrecord_dir, 'train_dataset.tfrecord'))
    # Test split.
    _write_split(data_dir, vocabulary,
                 image_name_list[split:NUM_EXAMPLES_PER_EPOCH],
                 os.path.join(tfrecord_dir, 'test_dataset.tfrecord'))


def _write_split(data_dir, vocabulary, image_names, tfrecord_path):
    '''Write one split of images and labels into a single tfrecord file.

    @param data_dir: directory the image files live in
    @param vocabulary: dict mapping one label character -> int id
    @param image_names: filenames (with .jpg extension) to encode
    @param tfrecord_path: output tfrecord file path
    '''
    writer = tf.python_io.TFRecordWriter(tfrecord_path)
    try:
        for name in image_names:
            # BUG FIX: the original used name.strip('.jpg'), which removes
            # any of the characters '.', 'j', 'p', 'g' from BOTH ends of
            # the string (e.g. 'jpg1.jpg' -> '1'), corrupting labels.
            # splitext removes only the extension.
            stem = os.path.splitext(name)[0]
            label = [vocabulary[ch] for ch in stem]

            # Read as a 3-channel colour image; skip unreadable files.
            image_raw = cv2.imread(os.path.join(data_dir, name), 1)
            if image_raw is None:
                continue
            height, width, _ = image_raw.shape
            # Normalise height to 32 px, keeping the aspect ratio, so the
            # reader can set_shape([32, None, 3]) and batch with dynamic_pad.
            ratio = 32 / float(height)
            image = cv2.resize(image_raw, (int(width * ratio), 32))
            # Re-encode to JPEG bytes (compact storage inside the record).
            is_success, image_buffer = cv2.imencode('.jpg', image)
            if not is_success:
                continue

            example = tf.train.Example(features=tf.train.Features(feature={
                'label': int64_list_feature(label),
                # tobytes() replaces the deprecated tostring() alias.
                'image': bytes_feature(image_buffer.tobytes())}))
            writer.write(example.SerializeToString())
    finally:
        # Always close the writer so the record file is flushed to disk.
        writer.close()


def read_tfrecord(filename, batch_size, is_train=True):
    """Build a queued input pipeline reading batches from a tfrecord file.

    Decodes the JPEG back to a tensor of shape [32, None, 3] (height is
    fixed at 32 by the writer; width varies per image), so batching must
    use dynamic_pad=True to zero-pad images and labels to the widest
    element in each batch.

    @param filename: path to the tfrecord file
    @param batch_size: number of examples per batch
    @param is_train: when True, batch with 4 reader threads; otherwise
                     use the single-threaded default
    @return: (image_batch, label_batch, sequence_length_batch) tensors;
             labels come back as a SparseTensor (VarLenFeature)
    @raise ValueError: if the tfrecord file does not exist
    """
    if not os.path.exists(filename):
        # Fixed typo in original message ("connot").
        raise ValueError('cannot find tfrecord file in path')

    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialize_example = reader.read(filename_queue)
    image_features = tf.parse_single_example(serialized=serialize_example,
                                             features={
                                                 'label': tf.VarLenFeature(dtype=tf.int64),
                                                 'image': tf.FixedLenFeature([], tf.string)
                                             })
    image = tf.image.decode_jpeg(image_features['image'])
    # Height is always 32; width is dynamic, channels 3 (colour JPEG).
    image.set_shape([32, None, 3])
    image = tf.cast(image, tf.float32)

    label = tf.cast(image_features['label'], tf.int32)
    # Sequence length = image width / 4 (presumably the downsampling
    # factor of the downstream CNN — confirm against the model).
    sequence_length = tf.cast(tf.shape(image)[-2] / 4, tf.int32)

    # The two original branches differed only in num_threads (train: 4,
    # test: the default 1) — merged into a single tf.train.batch call.
    num_threads = 4 if is_train is True else 1
    image_batch, label_batch, sequence_length_batch = tf.train.batch(
        [image, label, sequence_length],
        batch_size=batch_size,
        dynamic_pad=True,
        num_threads=num_threads,
        capacity=1000 + 3 * batch_size)
    return image_batch, label_batch, sequence_length_batch


def main(argv):
    """Script entry point: write the tfrecord files, then read two
    training batches back and print the dense labels and sequence lengths
    as a sanity check."""
    data_dir = "/textGenation/english/Train_en_10000/Train"
    tfrecord_dir = '/textGenation/english/Train_en_10000/'
    generation_TFRecord(data_dir, tfrecord_dir)

    train_tfrecord = os.path.join(tfrecord_dir, 'train_dataset.tfrecord')
    image_op, label_op, seq_len_op = read_tfrecord(train_tfrecord, 32)
    # Labels are stored sparse (VarLenFeature); densify for printing.
    dense_label_op = tf.sparse_tensor_to_dense(label_op)

    with tf.Session() as sess:
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for _ in range(2):
            _, _, seq_len_val, dense_label_val = sess.run(
                [image_op, label_op, seq_len_op, dense_label_op])
            print(dense_label_val)
            print(seq_len_val)

        coord.request_stop()
        coord.join(threads=threads)


if __name__ == '__main__':
    # tf.app.run() parses command-line flags and then invokes main(argv).
    tf.app.run()

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

马鹤宁

谢谢

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值