Tensorflow—TFRecord文件生成与读取

Tensorflow—TFRecord文件生成与读取

微信公众号:幼儿园的学霸
个人的学习笔记,关于OpenCV,关于机器学习, …。问题或建议,请公众号留言;

目录

一.为什么使用TFRecord

关于 tensorflow 读取数据, 官网提供了3种方法:

  • Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据。
  • Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中。
  • 在声明tf.variable变量或numpy数组时保存数据。受限于内存大小,适用于数据较小的情况。

我们在刚学习Tensorflow时,几乎所有的例子都是使用第一种或第三种方法,因为例子中的数据量都比较少,而当数据量比较大时,由于这些文件被散列存着,这样不仅占用磁盘空间,并且在被一个个读取的时候会非常慢,繁琐,占用大量内存空间(有的大型数据不足以一次性加载),效率比较低。此时,第二种方法就会发挥巨大的作用,因此它存储的是二进制文件,PC读取二进制文件是比读取格式文件要快的多。

TFRecords是TensorFlow中的设计的一种内置的文件格式,它是一种二进制文件。其具有以下优点:

  • 统一不同输入文件的框架。
  • 它是更好的利用内存,更方便复制和移动。TFRecord压缩的二进制文件采用protocal buffer序列化,只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
  • 是用于将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

二.TFRecord文件生成

在将其他数据生成为TFRecords文件存储的时候,需要经过两个步骤:

  • 建立TFRecord生成器(存储器)
  • 构造每个样本的Example模块

1.TFRecord生成器

writer = tf.python_io.TFRecordWriter(record_path)
#for :
    writer.write(tf_example.SerializeToString())
#...
writer.close()

此处的writer就是我们的TFRecord生成器,输出参数record_path为我们将要生成的TFRecord文件的存储路径。
构建完毕TFRecord文件生成器后就可以调用生成器的write()方法向文件中写入一个字符串记录(即一个样本),不断的调用该方法以将每一个样本存储于生成器中,最后调用close()函数来关闭文件的写操作。
其中writer.write()的参数为一个序列化的Example,通过Example.SerializeToString()来实现,它的作用是将Example中的map压缩为二进制,节约大量空间。而Example是通过Example模块生成的。

2.Example模块

首先们来看一下Example协议块是什么样子的。

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

从定义中可以看出tf.train.Example是以字典的形式存储数据格式,string为字典的key值,字典的属性值有三种类型:bytes、float、int64。详解如下:
(1)tf.train.Example(features = None)

  • 写入tfrecords文件
  • features : tf.train.Features类型的特征实例
  • return : example协议格式块

(2)tf.train.Features(feature = None)

  • 构造每个样本的信息键值对
  • feature : 字典数据,key为要保存的名字,value为tf.train.Feature实例
  • return : Features类型

(3)tf.train.Feature(**options)
options可以选择如下三种格式数据:

  • bytes_list = tf.train.BytesList(value = [Bytes])
  • int64_list = tf.train.Int64List(value = [Value])
  • float_list = tf.trian.FloatList(value = [Value])
    那我们如何构造一个tf_example呢?下面有一个简单的例子
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    if not isinstance(value, collections.Iterable):
       value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.tobytes()]))
    
tf_example = tf.train.Example(#key-value形式
        features=tf.train.Features(
            feature={
                'image/image': bytes_feature(image),
                'image/shape': int64_list_feature(list(image.shape)),
                "bbox/xmins": int64_list_feature(xmins),
                "bbox/ymins": int64_list_feature(ymins),
                "bbox/xmaxs": int64_list_feature(xmaxs),
                "bbox/ymaxs": int64_list_feature(ymaxs),
                'image/classes': int64_list_feature(classes),
            }
        ))    

3.生成TFRecord文件完整代码实例

代码及图片路径:https://github.com/leonardohaig/yolov3_tensorflow/blob/master/generate_tfrecord.py

1)准备图片文件夹存放图片,此处我采用了小浣熊数据集
2)准备标签文件,文件格式如下:

xxx/xxx.jpg 18.19,6.32,424.13,421.83,20 323.86,2.65,640.0,421.94,20 
xxx/xxx.jpg 48,240,195,371,11 8,12,352,498,14
# image_path x_min, y_min, x_max, y_max, class_id  x_min, y_min ,..., class_id 

每一行表示图像路径,矩形框的左上顶点、右下顶点坐标,该矩形框类别 矩形框的左上顶点、右下顶点坐标,该矩形框类别 …
这两份文件分别为训练集和验证集。在该项目中已制作好,位于data/classes文件夹中,分别为data/classes/train_yoloTF.txttest_yoloTF.txt
3)生成TFRecord文件,generate_tfrecord.py

import os
import collections
import sys
import cv2

import tensorflow as tf

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
    if not isinstance(value, collections.Iterable):
       value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.tobytes()]))


def create_tf_example(annotation):

    '''
    创建一条tf_example格式的数据
    :param annotation:list类型,一行label标签,内容:图片路径,目标位置,类别,....
    :return:
    '''

    line = annotation.split()

    image_path = line[0]
    assert os.path.exists(image_path),'{} not exist !'.format(image_path)

    xmins = []
    ymins = []
    xmaxs = []
    ymaxs = []
    classes = []
    for content in line[1:]:
        content = list(map(int,content.split(','))) #将其转换为int list
        xmins.append(content[0])
        ymins.append(content[1])
        xmaxs.append(content[2])
        ymaxs.append(content[3])
        classes.append(content[4])

    image = cv2.imread(image_path,cv2.IMREAD_UNCHANGED)
    image = cv2.resize(image, (413, 413), interpolation=cv2.INTER_LINEAR)

    tf_example = tf.train.Example(#key-value形式
        features=tf.train.Features(
            feature={
                'image/image': bytes_feature(image),
                'image/shape': int64_list_feature(list(image.shape)),
                "bbox/xmins": int64_list_feature(xmins),
                "bbox/ymins": int64_list_feature(ymins),
                "bbox/xmaxs": int64_list_feature(xmaxs),
                "bbox/ymaxs": int64_list_feature(ymaxs),
                'image/classes': int64_list_feature(classes),
            }
        ))

    #print(tf_example)

    return tf_example

def generate_tfrecord(labelFile, recordPath):
    '''

    :param labelFile: label file 文件路径
    :param recordPath: 创建的TFRecord文件存储路径
    :return:
    '''

    file_dir = os.path.dirname(os.path.abspath(recordPath))# 获取当前文件所在目录的绝对路径
    assert os.path.exists(file_dir),'{} not exist !'.format(file_dir)

    with open(labelFile,'r') as file:
        # writer = tf.python_io.TFRecordWriter(recordPath)
        writer = tf.io.TFRecordWriter(recordPath)
        for line in file.readlines():
            # annotation = line.split('\n') # 去除末尾的'\n'
            tf_example = create_tf_example(line)
            writer.write(tf_example.SerializeToString())
        writer.close()

    return True
    
    
if __name__ == '__main__':
    # 生成TFRecords文件
    generate_tfrecord('/home/liheng/PycharmProjects/yolov3_tensorflow/data/classes/test_yoloTF.txt',
                      './test.tfrecord')

Note:大多数情况下图片进行encode编码保存在tfrecord时 是一个一维张量,shape为(1,),因此有必要将尺寸信息保存下来,以便于恢复图片

三.TFRecord文件读取

1.基本流程

文件读取和文件创建的流程基本相同,只是中间多了一步解析过程。
1)将TFRecord文件test.record文件读入到文件队列中,如下所示:

filename_queue = tf.train.string_input_producer([tfrecords_filename])

使用tf.train.string_input_producer生成一个输入文件队列。这里我们的输入列表文件只有一个[path],而如果当训练数据比较大时,就需要将数据拆分多个TFRecord文件来提高处理效率。
例如,Cifar10的例子中,将训练集数据拆分为5个bin文件以提高文件处理效率,Cifar10例子使用下面方式获取所有的训练集输入文件列表,而Tensorflow既然让我们将训练数据拆分为多个TFRecord文件,那么它也提供函数tf.train.match_filenames_once,通过正则表达式获取某个目录下的输入文件列表。

filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
filenames =tf.train.match_filenames_once('data_batch_×')

2)通过TFRecordReader读入生成的文件队列

reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue) #返回文件名和文件

3)通过解析器tf.parse_single_example将我们的example解析出来
当然,也可以采用tf.parse_example来解析,和tf.parse_single_example区别在于后者解析的是单个example.

2.代码示例

代码路径:https://github.com/leonardohaig/yolov3_tensorflow/blob/master/generate_tfrecord.py

def read_tfrecord(batchsize, recordFileList):
    '''
    从TFRecords文件当中读取图片数据(解析example)
    :param batchsize:
    :param recordFileList: TFRecord file文件列表,list类型
    :return:
    '''

    assert isinstance(recordFileList, collections.Iterable),'param recordFileList need type list!'

    # 1.构造文件队列
    filename_queue = tf.train.string_input_producer(recordFileList,num_epochs=None, shuffle=True)  # 参数为文件名列表

    # 2.构造阅读器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回文件名和文件

    # 3.批处理,此处批处理提前放置
    batch = tf.train.shuffle_batch([serialized_example],batch_size=batchsize, capacity=batchsize * 5, min_after_dequeue=batchsize * 2,num_threads=1)

    # 4.解析协议块,返回的值是字典.采用tf.parse_example,其返回的Tensor具有batch的维度
    _feature = {'image/image': tf.io.FixedLenFeature([], tf.string),
                'image/shape': tf.io.FixedLenFeature([3], dtype=tf.int64),
                'bbox/xmins': tf.io.VarLenFeature(dtype=tf.int64),
                'bbox/ymins': tf.io.VarLenFeature(dtype=tf.int64),
                'bbox/xmaxs': tf.io.VarLenFeature(dtype=tf.int64),
                'bbox/ymaxs': tf.io.VarLenFeature(dtype=tf.int64),
                'image/classes': tf.io.VarLenFeature(dtype=tf.int64)}
    features = tf.io.parse_example(batch,features=_feature)

    # 得到图片shape信息
    image_shape = features['image/shape']

    # 处理图片数据,由于是一个string,要进行解码,  #将字节转换为数字向量表示,字节为一字符串类型的张量
    # 如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型
    # decode_raw()可以将数据从string,bytes转换为int,float类型的
    image_raw = features['image/image']# Get the image as raw bytes.
    image_tensor = tf.decode_raw(image_raw, tf.uint8)# Decode the raw bytes so it becomes a tensor with type.
    # 转换图片的形状,此处需要用动态形状进行转换
    image_tensor = tf.reshape(image_tensor,shape=[batchsize,image_shape[0][0],image_shape[0][1],image_shape[0][2]])
    image_tensor = tf.image.convert_image_dtype(image_tensor,
                                                dtype=tf.float32)  # The type is now uint8 but we need it to be float.

    bbox_xmins = features['bbox/xmins']
    bbox_ymins = features['bbox/ymins']
    bbox_xmaxs = features['bbox/xmaxs']
    bbox_ymaxs = features['bbox/ymaxs']
    bbox_classes = features['image/classes']
    bbox_classes = tf.cast(bbox_classes,dtype=tf.int32)

    bbox_xmins = tf.sparse.to_dense(bbox_xmins)
    bbox_ymins = tf.sparse.to_dense(bbox_ymins)
    bbox_xmaxs = tf.sparse.to_dense(bbox_xmaxs)
    bbox_ymaxs = tf.sparse.to_dense(bbox_ymaxs)
    bbox_classes = tf.sparse.to_dense(bbox_classes)


    return image_tensor,bbox_xmins,bbox_ymins,bbox_xmaxs,bbox_ymaxs,bbox_classes

if __name__ == '__main__':

    # # 生成TFRecords文件
    # generate_tfrecord('/home/liheng/PycharmProjects/yolov3_tensorflow/data/classes/test_yoloTF.txt',
    #                   './test.tfrecord')


    # 从已经存储的TFRecords文件中解析出原始数据
    image_tensor, bbox_xmins, bbox_ymins, bbox_xmaxs, bbox_ymaxs, bbox_classes = read_tfrecord(4,['./test.tfrecord'])
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        # 线程协调器
        coord = tf.train.Coordinator()
        # 开启线程
        thread = tf.train.start_queue_runners(sess, coord)

        for i in range(5):
            _image_tensor, _bbox_xmins, _bbox_ymins, _bbox_xmaxs,\
            _bbox_ymaxs, _bbox_classes = sess.run([image_tensor,bbox_xmins,
                                                   bbox_ymins,bbox_xmaxs,bbox_ymaxs,bbox_classes])


            print(i,_image_tensor.shape)
            #print(_bbox_xmins)

            cv2.imshow('image0', _image_tensor[0])
            cv2.imshow('image1', _image_tensor[1])
            cv2.waitKey(0)
            cv2.destroyAllWindows()




        # 回收线程
        coord.request_stop()
        coord.join(thread)

参考资料

1.Tensorflow(一) TFRecord生成与读取.
2.TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
3.Tensorflow针对不定尺寸的图片读写tfrecord文件总结



下面的是我的公众号二维码图片,欢迎关注。
图注:幼儿园的学霸

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Python读取TFRecord文件的方法如下: ```python import tensorflow as tf # 创建一个TFRecordDataset对象 dataset = tf.data.TFRecordDataset('data.tfrecord') # 定义读取函数 def parser(record): features = { 'image': tf.io.FixedLenFeature([], dtype=tf.string), 'label': tf.io.FixedLenFeature([], dtype=tf.int64) } parsed = tf.io.parse_single_example(record, features) image = tf.io.decode_jpeg(parsed['image'], channels=3) label = parsed['label'] return image, label # 应用读取函数到每个record dataset = dataset.map(parser) # 创建迭代器 iterator = dataset.make_one_shot_iterator() # 获取数据 image, label = iterator.get_next() ``` 以上代码演示了如何读取名为`data.tfrecord`的TFRecord文件,并解析其中的图像和标签信息。在解析函数`parser`中,我们先定义了TFRecord文件中包含的特征信息,然后使用`tf.io.parse_single_example`函数解析单个record,并对图像数据进行解码。最后,我们使用`map`函数将解析函数应用到每个record上。 当然,如果您使用的是PyTorch,也可以使用以下代码读取TFRecord文件: ```python import torch import torchvision.datasets as datasets import torchvision.transforms as transforms # 定义解析函数 def parser(record): features = { 'image': tf.io.FixedLenFeature([], dtype=tf.string), 'label': tf.io.FixedLenFeature([], dtype=tf.int64) } parsed = tf.io.parse_single_example(record, features) image = tf.io.decode_jpeg(parsed['image'], channels=3) label = parsed['label'] return image, label # 创建数据集对象 dataset = datasets.DatasetFolder( 'data.tfrecord', loader=lambda x: torch.load(x), extensions=('tfrecord') ) # 应用解析函数到每个record dataset.transform = transforms.Compose([ parser ]) # 创建数据加载器 dataloader = torch.utils.data.DataLoader( dataset, batch_size=32, shuffle=True ) # 获取数据 for images, labels in dataloader: # 使用数据进行训练或预测 pass ``` 以上代码演示了如何使用PyTorch的`DatasetFolder`读取TFRecord文件,并使用解析函数`parser`解析图像和标签信息。最后,我们创建了一个数据加载器,并使用其中的数据进行训练或预测。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值