生成和读取TFRecord文件

下载数据cifar-10数据集(形式为分文件夹存放原始图片),文件目录结构如下所示:

我们只选择演示其中的train文件夹中的图片。

生成TFRecord文件:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import glob
import tensorflow as tf

# 指定图片数据路径
PATH = 'cifar-10'
# 指定输出TFRecord文件
OUT_DIR = 'tfrecord_file'


# 以下函数将一个value转换为与tf适配的格式
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


# 利用一张图片创建一个example
def make_example(image_string, label):
    # 如果我们想获取图片的高宽和深度,则将读取到的数据先解码为jpeg图片格式,然后用shape取值
    # image_shape = tf.image.decode_jpeg(image_string).shape
    feature = {
        # 'height': _int64_feature(image_shape[0]),
        # 'width': _int64_feature(image_shape[1]),
        # 'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string)
    }

    exam = tf.train.Example(features=tf.train.Features(feature=feature))
    return exam


# 定义一个TFRecordWriter实例
write = tf.io.TFRecordWriter(OUT_DIR)

# 创建一个index:img_name的字典
index_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
cate_list = ['airplane', 'mobilephone', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img_dict = zip(index_list, cate_list)

# 循环所有类别的图片文件夹,生成tfrecord文件
for index, img_cate in img_dict:
    # 生成训练集每个图片的路径
    train_img_list = glob.glob(os.path.join(PATH, 'train/{}/*.jpg'.format(index)))
    # test_img_list = glob.glob(os.path.join(PATH, 'test/{}/*.jpg'.format(index)))

    print("start to make train file index = {}:".format(index))
    count = 0
    # 开始循环读取每张图片,然后连同label,一起打包到TFRecord文件中
    for per_img_dir in train_img_list:

        # 使用gfile.Gfile来读图片文件
        with tf.io.gfile.GFile(per_img_dir, 'rb') as per_img_fp:
            img_data = per_img_fp.read()
        # 为每一张图片+label创建一个example
        example = make_example(img_data, index)

        # 打印看看example具体内容结构
        if count == 0:
            for line in str(example).split('\n')[:15]:
                print(line)

        # 将每个example写入到write中
        write.write(example.SerializeToString())
        count += 1
        if count % 1000 == 0:
            print(count)

write.close()

其中每一个图片所生成的example结构如下:

features {
  feature {
    key: "image_raw"
    value {
      bytes_list {
        value: "\377\330\377\340\000\020JFIF\000\001\001\000\000\001\000\001\000\000\377\333\000C\000\010\006\006\007\006\005\010\007\007\007\t\t\010\n\014\024\r\014\013\013\014\032
    ......
2)g\230\266\025Q\to\312\276\252\360\037\206\233B\360\355\231\276\211N\252\321\177\244JynNB\223\355\300\374)F-=\004\225\217\377\331"
      }
    }
  }
  feature {
    key: "label"
    value {
      int64_list {
        value: 0
      }
    }
  }

其中包含我们再feature中指定的label和image_raw。

 

读取TFRecord中的数据:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值