TFRecords数据格式的制作、读取及解析

 TFRecords 是tensorflow使用的一种二进制数据格式,它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。使用框架定义的数据格式好处是有强大的框架支持,例如封装了数据解析、多线程等操作, 使用起来很方便。

 TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。从TFRecords文件中读取数据可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。废话不多说了,接下来看代码吧。

1、将bmp,jpg,png等格式的数据制作成TFR格式

def convert2tfr(path,name):
    classes = 81   # 类别数目
    writer = tf.python_io.TFRecordWriter(name+'.tfrecords')  # 要生成的文件
    for index in range(classes):
        class_path = path + str(index) + '/'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一个图片的地址
            img = Image.open(img_path)
            img = img.convert("RGB") #转换成RGB格式
            img = img.resize((32,32))
            img_raw = img.tobytes()  # 将图片转化为二进制格式
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # example对象对label和image数据进行封装
            writer.write(example.SerializeToString())  # 序列化为字符串
    writer.close()
2、 将tfr格式的数据解析为原来的图片格式

def tfr2bmp(filename,dir,num):  # tfr文件名,解析后存放的目录,图片数量
    if not os.path.exists('read_img'):
        os.mkdir('read_img')
    os.mkdir('read_img/'+dir)
    filename_queue = tf.train.string_input_producer([filename])  # 读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [32, 32,3])
    label = tf.cast(features['label'], tf.int32)
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(num):
            example, l = sess.run([image, label])  # 在会话中取出image和label
            img = Image.fromarray(example, 'RGB')  # 这里Image是之前提到的
            img.save('read_img/'+dir+'/'+ str(i) + '_Label_' + str(l) + '.bmp')  # 存下图片
            #print(example, l)
        coord.request_stop()
        coord.join(threads)
3、tensorflow训练时读取操作

 def read_and_decode(filename,batch_size):
    filename_queue = tf.train.string_input_producer([filename])# create a queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#return file_name and file
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })#return image and label

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [32,32,3])  #reshape image to 32*32*3

    # 3通道转为1通道
    img =  tf.image.rgb_to_grayscale(img) #图像灰度化 32*32*1
    img = tf.reshape(img, [32,32])  #reshape image to 32*32

    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
    label = tf.cast(features['label'], tf.int64) #throw label tensor
    img_batch, label_batch = tf.train.shuffle_batch([img,label], batch_size=batch_size,
                                        capacity=1000+batch_size*3, min_after_dequeue=1000)
    label_batch = tf.one_hot(label_batch, depth=81) # depth=类别数目

    return img_batch, label_batch
完整的代码如下:
import os
import tensorflow as tf
from PIL import Image

def convert2tfr(path,name):
    classes = 81   # 类别数目
    writer = tf.python_io.TFRecordWriter(name+'.tfrecords')  # 要生成的文件
    for index in range(classes):
        class_path = path + str(index) + '/'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每一个图片的地址
            img = Image.open(img_path)
            img = img.convert("RGB") #转换成RGB格式
            img = img.resize((32,32))
            img_raw = img.tobytes()  # 将图片转化为二进制格式
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # example对象对label和image数据进行封装
            writer.write(example.SerializeToString())  # 序列化为字符串
    writer.close()


def read_and_decode(filename,batch_size):
    filename_queue = tf.train.string_input_producer([filename])# create a queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#return file_name and file
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })#return image and label

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [32,32,3])  #reshape image to 32*32*3

    # 3通道转为1通道
    img =  tf.image.rgb_to_grayscale(img) #图像灰度化 32*32*1
    #img = tf.reshape(img, [32,32])  #reshape image to 32*32

    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
    label = tf.cast(features['label'], tf.int64) #throw label tensor
    img_batch, label_batch = tf.train.shuffle_batch([img,label], batch_size=batch_size,
                                        capacity=1000+batch_size*3, min_after_dequeue=1000)
    label_batch = tf.one_hot(label_batch, depth=81)

    return img_batch, label_batch

def tfr2bmp(filename,dir,num):
    if not os.path.exists('read_img'):
        os.mkdir('read_img')
    os.mkdir('read_img/'+dir)
    filename_queue = tf.train.string_input_producer([filename])  # 读入流中
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 取出包含image和label的feature对象
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [32, 32,3])
    label = tf.cast(features['label'], tf.int32)
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(num):
            example, l = sess.run([image, label])  # 在会话中取出image和label
            img = Image.fromarray(example, 'RGB')  # 这里Image是之前提到的
            img.save('read_img/'+dir+'/'+ str(i) + '_Label_' + str(l) + '.bmp')  # 存下图片
            #print(example, l)
        coord.request_stop()
        coord.join(threads)



if __name__ == '__main__':
    #将图片转换为TFR格式
    convert2tfr('train_img/','tai_train1')
    convert2tfr('test_img/','tai_test1')

    #读取TFR格式数据
    #read_and_decode('tai_test.tfrecords',batch_size)
    #read_and_decode('tai_train.tfrecords',batch_size)

    #提取TFR格式数据并保存
    tfr2bmp(filename='tai_test.tfrecords',dir='test_img',num=9834)
    tfr2bmp(filename='tai_train.tfrecords',dir='train_img',num=48961)


 

 

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值