TensorFlow 之 TFRecord

原英文文章是:https://www.skcript.com/svr/why-every-tensorflow-developer-should-know-about-tfrecord/

为什么每个TensorFlow开发人员都应该了解TFRecord!

经过几天的Tensorflow,每个初学者都会遇到这个疯狂的Tensorflow文件格式,称为Tfrecords。大多数批处理操作不是直接从图像完成的,而是将它们转换为单个tfrecord文件(图像是numpy数组,标签是字符串列表)。了解这种转换的目的以及它对工作流程的真正好处始终是初学者的梦魇。所以在这里,我通过简单到复杂的例子更容易理解。

什么是TFRecord?

根据Tensorflow的文档,

“…方法是将您拥有的任何数据转换为支持的格式。这种方法可以更轻松地混合和匹配数据集和网络架构。TensorFlow的推荐格式是TFRecords文件,其中包含tf.train.Example协议缓冲区(其中包含功能作为字段)。“

所以,我建议维护可扩展架构和标准输入格式的更简单方法是将其转换为tfrecord文件。

让我用初学者的语言来解释,

因此,当您使用图像数据集时,您首先要做的是什么?分成训练,测试,验证集,对吧?如果存在像date这样的偏差参数,我们也会将其改组为没有任何有偏差的数据分布。

做文件夹结构然后保持洗牌是不是繁琐的工作?

如果所有内容都在一个文件中,并且我们可以使用该文件在随机位置动态混洗,并且还可以更改train:test:validate与整个数据集的比率。听起来有一半的工作量被删除了吗?现在不再是初学者维持不同分裂的噩梦了。这可以通过tfrecords实现。

让我们看看代码之间的区别 - Naive vs Tfrecord

Naive普通方式

import os 

import glob

import random

# Loading the location of all files - image dataset
# Considering our image dataset has apple or orange
# The images are named as apple01.jpg, apple02.jpg .. , orange01.jpg .. etc.

images = glob.glob('data/*.jpg')

# Shuffling the dataset to remove the bias - if present
random.shuffle(images)
# Creating Labels. Consider apple = 0 and orange = 1

labels = [ 0 if 'apple' in image else 1 for image in images ]
data = list(zip(images, labels))

# Ratio

data_size = len(data)
split_size = int(0.6 * data_size)

# Splitting the dataset

training_images, training_labels = zip(*data[:split_size])
testing_images, testing_labels = zip(*data[split_size:])

Tfrecord方式

按照这五个步骤操作,您将完成一个tfrecord文件,该文件包含您继续处理的所有数据。

  1. 使用tf.python_io.TFRecordWriter打开tfrecord文件,并开始写作。
  2. 在写入tfrecord文件之前,应将图像数据和标签数据转换为正确的数据类型。(byte,int,float)
  3. 现在数据类型被转换为 tf.train.Feature
  4. 最后创建一个Example Protocol Buffer使用tf.Example并使用转换后的功能。使用serialize()函数序列化示例。
  5. 写序列化Example。
import tensorflow as tf 

import numpy as np

import glob

from PIL import Image

# Converting the values into features
# _int64 is used for numeric values

def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# _bytes is used for string/char values

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
tfrecord_filename = 'something.tfrecords'

# Initiating the writer and creating the tfrecords file.

writer = tf.python_io.TFRecordWriter(tfrecord_filename)

# Loading the location of all files - image dataset
# Considering our image dataset has apple or orange
# The images are named as apple01.jpg, apple02.jpg .. , orange01.jpg .. etc.

images = glob.glob('data/*.jpg')
for image in images[:1]:
  img = Image.open(image)
  img = np.array(img.resize((32,32)))
label = 0 if 'apple' in image else 1
feature = { 'label': _int64_feature(label),
              'image': _bytes_feature(img.tostring()) }

# Create an example protocol buffer

 example = tf.train.Example(features=tf.train.Features(feature=feature))

# Writing the serialized example.

 writer.write(example.SerializeToString())

writer.close()

如果您仔细查看所涉及的过程,则非常简单。

Data -> FeatureSet -> Example -> Serialized Example -> tfRecord.

所以要回读它,这个过程是相反的。

tfRecord -> SerializedExample -> Example -> FeatureSet -> Data

从tfrecord阅读

import tensorflow as tf 
import glob
reader = tf.TFRecordReader()
filenames = glob.glob('*.tfrecords')
filename_queue = tf.train.string_input_producer(
   filenames)
_, serialized_example = reader.read(filename_queue)
feature_set = { 'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenFeature([], tf.int64)
           }
           
features = tf.parse_single_example( serialized_example, features= feature_set )
label = features['label']
 
with tf.Session() as sess:
  print sess.run([image,label])

您也可以使用 tf.train.shuffle_batch()来随机播放文件

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值