TensorFlow:生成 TFRecord 文件并使用 tf.data 读取

一、生成 TFRecord 文件(以图片为例)

#!/usr/bin/env python
# -*- coding:utf-8 -*- 
#Author:  1477517404@qq.com

import tensorflow as tf
from PIL import Image
import os
import io

def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: The integer to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def bytes_feature(value):
    """Wrap a single bytes object in a ``tf.train.Feature``.

    Args:
        value: The bytes to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def process_image_channels(image):
    """Return *image* in RGB mode together with a "was converted" flag.

    Args:
        image: A PIL image (anything exposing ``mode``, ``split`` and
            ``convert`` the way ``PIL.Image.Image`` does).

    Returns:
        A ``(image, process_flag)`` tuple. ``process_flag`` is ``False`` when
        the input was already RGB (the same object is returned) and ``True``
        when a new RGB image had to be produced.
    """
    print(image.mode)
    if image.mode == 'RGB':
        # Already in the target mode: hand the image back untouched.
        converted = False
    elif image.mode == 'RGBA':
        # Drop the alpha band and rebuild the image from the colour bands.
        red, green, blue, _alpha = image.split()
        image = Image.merge("RGB", (red, green, blue))
        converted = True
    else:
        # Any other mode is converted by PIL.
        image = image.convert('RGB')
        converted = True

    print('process_flag is : ', converted)
    return image, converted

def create_tf_example(image_path,label,resize=None):
    """Build a ``tf.train.Example`` describing one image file.

    The file is read, passed through ``process_image_channels`` and then
    re-encoded as JPEG before being stored.

    Args:
        image_path: Path of the image file to read.
        label: Integer class label stored under ``image/class/label``.
        resize: Accepted for interface compatibility; currently not used.

    Returns:
        A ``tf.train.Example`` with the features ``image/encoded`` (JPEG
        bytes), ``image/format`` (``b'jpg'``), ``image/class/label``,
        ``image/height`` and ``image/width``.
    """
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        raw_bytes = fid.read()
    image = Image.open(io.BytesIO(raw_bytes))
    image, _ = process_image_channels(image)

    # Re-encode unconditionally: the record declares the format as 'jpg', so
    # the stored bytes must be JPEG even when the source file was not.
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, format='JPEG')
    encode_jpg = jpeg_buffer.getvalue()
    print(len(encode_jpg))

    width, height = image.size
    tf_example = tf.train.Example(
        features = tf.train.Features(
            feature = {
                "image/encoded":bytes_feature(encode_jpg),
                'image/format':bytes_feature(b'jpg'),
                # Store the caller's label (previously hard-coded to 1).
                'image/class/label':int64_feature(label),
                'image/height':int64_feature(height),
                'image/width':int64_feature(width)
            }
        )
    )
    return tf_example


def generate_tfrecord(annotation_dict,record_path,resize=None):
    """Serialize every annotated image into a single TFRecord file.

    Args:
        annotation_dict: Mapping of image path -> label.
        record_path: Destination path of the TFRecord file.
        resize: Forwarded unchanged to ``create_tf_example``.

    Image paths that do not exist are reported and skipped.
    """
    num_tf_example = 0
    # The context manager closes the writer even if an image fails to encode.
    with tf.io.TFRecordWriter(record_path) as writer:
        for image_path, label in annotation_dict.items():
            # Constructing a GFile object never tests existence; ask the
            # file system explicitly and skip entries that are missing.
            if not tf.io.gfile.exists(image_path):
                print("{} does not exist".format(image_path))
                continue
            tf_example = create_tf_example(image_path, label, resize)

            writer.write(tf_example.SerializeToString())
            num_tf_example += 1
            # Progress message after every second written example.
            if num_tf_example % 2 == 0:
                print("Create %d TF_example" % num_tf_example)

def get_annotation_dict(image_dir):
    """Map every image file directly inside *image_dir* to the label ``1``.

    Only regular files with a common image extension are included, so
    sub-directories and other files stored alongside the images (for example
    a previously generated ``.record`` file) are ignored.

    Args:
        image_dir: Directory to scan, with or without a trailing separator.

    Returns:
        A dict of ``{image_path: 1}`` with entries in sorted file-name order.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    annotation_dict = {}
    for file_name in sorted(os.listdir(image_dir)):
        # os.path.join yields a correct path whether or not image_dir ends
        # with a separator.
        image_path = os.path.join(image_dir, file_name)
        if not os.path.isfile(image_path):
            continue
        if not file_name.lower().endswith(image_extensions):
            continue
        annotation_dict[image_path] = 1  # fixed placeholder label, for convenience
    return annotation_dict

def main():
    """Write ``data/image.record`` from the entries listed in ``data/``."""
    source_dir = 'data/'
    output_path = 'data/image.record'
    generate_tfrecord(get_annotation_dict(source_dir), output_path)


# Run the conversion only when executed as a script.
if __name__ == "__main__":
    main()

二、tf.data 读取生成的文件

#!/usr/bin/env python
# -*- coding:utf-8 -*- 
#Author:1477517404@qq.com

import tensorflow as tf
import multiprocessing as mt
from PIL import Image
import matplotlib.pyplot as plt

def parser(record):
    """Decode one serialized ``tf.train.Example`` into ``(image, label)``.

    Args:
        record: A scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(image, label)``: the JPEG decoded with three channels and
        reshaped to ``(height, width, 3)`` using the stored dimensions, and
        the int64 class label.
    """
    features = {
        'image/encoded': tf.FixedLenFeature((), default_value='', dtype=tf.string),
        'image/format': tf.FixedLenFeature((), default_value='jpg', dtype=tf.string),
        'image/class/label': tf.FixedLenFeature([], default_value=0, dtype=tf.int64),
        'image/height': tf.FixedLenFeature([], default_value=0, dtype=tf.int64),
        'image/width': tf.FixedLenFeature([], default_value=0, dtype=tf.int64)

    }

    example = tf.parse_single_example(record,features)
    height = example['image/height']
    width = example['image/width']
    # Force three channels so the reshape below always matches the decoded data.
    decoded = tf.image.decode_jpeg(example['image/encoded'], channels=3)
    # Decoded images are laid out rows first, so the target shape is
    # (height, width, 3); using width twice only works for square images.
    image = tf.reshape(decoded, (height, width, 3))
    label = example['image/class/label']
    return image,label



def read_tfRecord(record_path='data/image.record', batch_size=5,
                  shuffle_buffer_size=100, num_epochs=5):
    """Build the ``tf.data`` input pipeline for a TFRecord file.

    Called without arguments it behaves exactly as before; the keyword
    parameters expose the values that used to be hard-coded.

    Args:
        record_path: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Size of the shuffle buffer.
        num_epochs: How many times the batched data is repeated.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(image, label)``.
    """
    dataset = tf.data.TFRecordDataset([record_path])  # read the record file
    dataset = dataset.map(parser)                     # decode each serialized example
    # Shuffle, then batch. The buffer has to be at least as large as the
    # dataset for the shuffle to be thorough.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size).batch(batch_size)
    dataset = dataset.repeat(num_epochs)              # repeat the batches num_epochs times
    return dataset

def main():
    """Iterate over the dataset, showing the first image of every batch.

    Prints the dataset's output shapes and types, then pulls up to 40 batches
    in a session, stopping early once the dataset is exhausted.
    """
    dataset = read_tfRecord()
    print('shapes:', dataset.output_shapes)
    print('types:', dataset.output_types)
    iterator = dataset.make_one_shot_iterator()
    next_op = iterator.get_next()
    with tf.Session() as sess:
        for step in range(40):
            print('--------------------------batch({})---------------------'.format(step))
            try:
                batch_image, batch_label = sess.run(next_op)
            except tf.errors.OutOfRangeError:
                # The iterator has no more batches to deliver.
                print("队列已经遍历完成!")
                break

            # Display the first image of the batch as a visual sanity check.
            first_image = Image.fromarray(batch_image[0])
            plt.figure()
            plt.imshow(first_image)
            plt.show()

            print(batch_image.shape)
            print(batch_label, batch_label.shape, batch_label[0])


# Run the reader only when executed as a script.
if __name__ == '__main__':
    main()


程序结果:
在这里插入图片描述

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值