关于tf2.X的tfrecord创建与读取

主要参考资料(Tensroflow官方文档):https://tensorflow.google.cn/tutorials/load_data/image

TFRecord 文件是一种用来存储一串二进制 blob 的简单格式。通过将多个示例打包进同一个文件内,TensorFlow 能够一次性读取多个示例,当使用一个远程存储服务,如 GCS 时,这对性能来说尤其重要。

我的项目是一个二分类问题,标签使用了独热编码,为了后期直接把图片和标签数据zip到一起喂给网络,所以制作了图片和标签的tfrecord文件,以及读取和解析的函数。中间有很多坑,折腾了很久,特别烦人,现将代码贴出来,给有用的人提供参考。

如有其他问题,欢迎打扰!

"""Created on Thu Jul  1 15:47:53 2021
@author: John_Huang"""
import os
import sys
import glob
import time
import tensorflow as tf
import numpy
import random


def preprocess_image(_image):
    _image = tf.image.decode_jpeg(_image, channels=3)
    _image = tf.image.resize(_image, [256, 256])#具体参数根据自己网络模型调整
    _image /= 255.0  # normalize to [0,1] range
    return _image


def load_and_preprocess_image(_path):
    _image = tf.io.read_file(_path)
    _image = preprocess_image(_image)
    return _image


def img_parse(x):
    result = tf.io.parse_tensor(x, out_type=tf.float32)
    result = tf.reshape(result, [256, 256, 3])#具体参数根据自己网络模型调整
    return result


def label_parse(y):
    result = tf.io.parse_tensor(y, out_type=tf.int64)
    result = tf.reshape(result, [2])
    return result
    


def tfrecord_create(_data_path, _img_tfrecord_abspath, _label_tfrecord_abspath):
    '''制作训练样本tdrecord格式文件'''
    all_imgs_path = glob.glob(os.path.join(_data_path, '*.jpg'))
    random.shuffle(all_imgs_path)
    label2index_dict = {'dis': 0, 'undis': 1}
    all_labels = [label2index_dict.get(os.path.basename(a_img).split('_')[0])for a_img in all_imgs_path]
    all_labels = tf.one_hot(indices=all_labels, depth=2, on_value=1, off_value=0, axis=-1)#使用了独热编码
    ds_image = tf.data.Dataset.from_tensor_slices(all_imgs_path)
    ds_image = ds_image.map(load_and_preprocess_image)
    ds_image = ds_image.map(tf.io.serialize_tensor)
    ds_image_record = tf.data.experimental.TFRecordWriter(_img_tfrecord_abspath)
    ds_image_record.write(ds_image)

    ds_label = tf.data.Dataset.from_tensor_slices(tf.cast(all_labels, tf.int64))
    ds_label = ds_label.map(tf.io.serialize_tensor)
    ds_label_record = tf.data.experimental.TFRecordWriter(_label_tfrecord_abspath)
    ds_label_record.write(ds_label)


def tfrecord_read_decode(_img_tfrecord_abspath, _label_tfrecord_abspath, _ratio=0.2):
    '''读取和解码训练样本tfrecord格式文件'''
    img_ds = tf.data.TFRecordDataset(_img_tfrecord_abspath)
    _tem = 0
    img_length = len([_tem for tem in img_ds])
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    img_ds = img_ds.map(img_parse, num_parallel_calls=AUTOTUNE)
    label_ds = tf.data.TFRecordDataset(_label_tfrecord_abspath)
    _tem = 0
    label_length = len([_tem for tem in label_ds])
    label_ds = label_ds.map(label_parse, num_parallel_calls=AUTOTUNE)

    if img_length != label_length:
        print('......严重错误:图片数据和标签数据tfrecord格式文件长度不一致!')
    else:
        img_label_ds = tf.data.Dataset.zip((img_ds, label_ds))
        # 训练数据、测试数据划分
        verify_ds_num = int(img_length*_ratio)
        train_ds_num = img_length - verify_ds_num
        print('样本总数据{}个, 训练数据有{}个, 验证数据有{}个.'.format(img_length, train_ds_num, verify_ds_num))
        train_ds = img_label_ds.skip(verify_ds_num)
        verify_ds = img_label_ds.take(verify_ds_num)
    return train_ds, train_ds_num, verify_ds, verify_ds_num

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

John H.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值