TFRecord 的写入和读取(序列化和反序列化)


TFrecord 是 tensorflow 中的数据集存储格式,它的写入和读取相当于序列化和反序列化的过程。

下面内容都是基于 tf2 版本来说明,两个版本中 tfrecord 的核心并没有改变,即序列化和反序列化。但是个人认为 tf2 里面对于读取 tfrecords 文件建立 Dataset 直接用于训练的支持更方便好用了,比如 batch、shuffle、repeat 等方面,所以在这点上基本摒弃了 tf1。

写入/序列化

(1)将数据读到内存,并转换为 tf.train.Example 对象,每个对象由若干个 tf.train.Feature 的字典组成。
(2)将 tf.train.Example 对象序列化为字符串,写入 TFRecord 文件。

读取/反序列化

(1)通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件,获得一个 tf.data.Dataset 的数据集对象。(tf1 中是要创建一个 reader 来读取 tfrecords 文件中的样例)
(2)通过 Dataset.map 对数据集对象中每个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 实现反序列化。
*:map 过程中,无法在 parse 内部进行某些处理,只能 parse 之后在 dataset 中迭代器“拿”出数据之后进行一些转换。

tf.train.Feature

上面多次提到的 tf.train.Feature 支持 3 种数据格式,因此对于各种各样的数据,必须处理成对应这三种格式,才能顺利写入/读取。

3 种格式如下:
tf.train.BytesList: 字符串或原始 Byte 文件,通过 bytes_list 传入。以图片或者数组等类型数据为例,需要转为字符串类型的再传入 BytesList,后面会有例子。
tf.train.FloatList:浮点数,通过 float_list 传入。
tf.train.Int64List:整数,通过 int64_list 传入。

具体实现

#-*-coding:utf-8-*-
import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2

# write
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

img_path = 'path/to/read/img'
tfrd_path = 'path/to/save/tfrecords'

image = cv2.imread(img_path)
h, w, c = image.shape
image_raw_data = gfile.FastGFile(img_path, 'rb').read()
label = 1

tfrd_writer = tf.io.TFRecordWriter(tfrd_path)
feature = {'image': _bytes_feature(image_raw_data),
           'label': _int64_feature(label),
           'imgH': _int64_feature(h),
           'imgW': _int64_feature(w),
           'imgC': _int64_feature(c)}
example = tf.train.Example(features=tf.train.Features(feature=feature))
tfrd_writer.write(example.SerializeToString())
tfrd_writer.close()

# read
raw_dataset = tf.data.TFRecordDataset(tfrd_path)
feature_description = {'image': tf.io.FixedLenFeature([], tf.string),
                       'label': tf.io.FixedLenFeature([], tf.int64),
                       'imgH': tf.io.FixedLenFeature([], tf.int64),
                       'imgW': tf.io.FixedLenFeature([], tf.int64),
                       'imgC': tf.io.FixedLenFeature([], tf.int64)
                       }

def parse(record):
    features = tf.io.parse_single_example(record, feature_description)
    images = tf.io.decode_jpeg(features['image'])
    labels = features['label']
    imgH, imgW, imgC = features['imgH'], features['imgW'], features['imgC']
    shape = [imgH, imgW, imgC]
    return images, labels, shape

dataset = raw_dataset.map(parse)

for image, label, shape in dataset:
    print(label)
    cv2.imshow('img', image.numpy())
    cv2.waitKey()

如果遇到需要将 numpy 写入 tfrecord,可以先将 numpy 转为字符串,然后写入;读取的时候再转为 numpy 即可,注意 dtype 的对应。

import numpy as np
gt_row_np = np.array([0, 0, 1, 0], dtype=np.uint8)
gt_row_str = gt_row_np.tostring()
gt_row = np.frombuffer(gt_row_str, dtype=np.uint8)
print('gt_row_np type: {}, dtype: {}, value: {}'.format(type(gt_row_np), gt_row_np.dtype, gt_row_np))
print('gt_row_str type: {}, value: {}'.format(type(gt_row_str), gt_row_str))
print('gt_row type: {}, dtype: {}, value: {}'.format(type(gt_row), gt_row.dtype, gt_row))

# output
gt_row_np type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0]
gt_row_str type: <class 'bytes'>, value: b'\x00\x00\x01\x00'
gt_row type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0]

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值