TensorFlow制作、读取TFRecord格式数据集

本文详细介绍了TensorFlow推荐的数据格式TFRecord的制作与读取方法,并提供了完整的代码示例,包括如何将图片转换为TFRecord格式,如何从TFRecord文件中读取数据,并展示了如何分批读取数据用于CNN训练。
摘要由CSDN通过智能技术生成

TFRecord数据格式是TensorFlow官方推荐的数据格式,不仅规范化读写,而且提高了IO效率。

1.制作TFRecord数据

原始数据为下图所示,文件夹名为类别标号,文件夹中存放的是各个类的图片:


制作TFRecord的代码为:

import os
import tensorflow as tf
from PIL import Image
import numpy as np

def create_record(inpath,outpath,size): # inpath:原始数据路径 outpath:TFRecord文件输出路径 size:统一图片尺寸,如[256,256]
    writer = tf.python_io.TFRecordWriter(outpath)
    dirs = os.listdir(inpath)
    for index, name in enumerate(dirs):
        class_path = inpath +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((size[1],size[0]),Image.ANTIALIAS) #将图片按照指定的size进行缩放,统一图片大小,注意resize方法的参数将长度放在前,宽度放在后
            img_raw = img.tobytes() #将图片转化为原生bytes
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()

2.读取TFRecord文件

读取的代码为:

def read_and_decode(filename,size): # filename:TFRecord文件路径 size:图片尺寸,如[256,256,1],最后一位表示颜色通道,1表示为灰度图像
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, size) 
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # 图像数据
    label = tf.cast(features['label'], tf.int32) # 标签
    return img, label

3.分批读取

做CNN训练时,通常将输入数据整理成batch分批训练,而不是一次性全部输入,下面这个方法对读取文件进一步进行了封装,能够分批读取数据。

def get_train_batch(filename,size,min_after_dequeue,batch_size,capacity):
    """
        # 获取批量的训练数据
        # filename TFRecord文件路径
        # size 图像大小
        # min_after_dequeue 队列中最小样本数,这个数字越大表示一个batch的数据越乱
        # batch_size 一个batch的大小
        # capacity 队列的大小,应该大于batch_size
    """
    image,label = read_and_decode(filename,size)
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
    return image_batch,label_batch

下面为不使用shuffle机制的批次获取方法,获取测试集的时候使用。

def get_test_batch(filename, size, batch_size, capacity):
    image, label = read_and_decode(filename, size)
    image_batch, label_batch = tf.train.batch(
        [image, label], batch_size=batch_size, capacity=capacity)
    return image_batch, label_batch

4.读取TFRecord文件并保持为图片

def read_and_save(inpath,outpath,size,num): # inpath:TFRecord文件路径 outpath:输出的图片路径 size:图片大小,如[256,256] num:图片个数
    filename_queue = tf.train.string_input_producer([inpath]) 
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   
    features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, size)
    label = tf.cast(features['label'], tf.int32)
    with tf.Session() as sess: 
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        coord=tf.train.Coordinator()
        threads= tf.train.start_queue_runners(coord=coord)
        for i in range(num):
            example, l = sess.run([image,label])
            img=Image.fromarray(example)
            img.save(outpath+'/'+str(i)+'_Label_'+str(l)+'.png') # 保存图片
        coord.request_stop()
        coord.join(threads)


评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值