主要参考资料(TensorFlow官方文档):https://tensorflow.google.cn/tutorials/load_data/image
TFRecord 文件是一种用来存储一串二进制 blob 的简单格式。通过将多个示例打包进同一个文件内,TensorFlow 能够一次性读取多个示例,当使用一个远程存储服务,如 GCS 时,这对性能来说尤其重要。
我的项目是一个二分类问题,标签使用了独热编码,为了后期直接把图片和标签数据zip到一起喂给网络,所以制作了图片和标签的tfrecord文件,以及读取和解析的函数。中间有很多坑,折腾了很久,特别烦人,现将代码贴出来,给有需要的人提供参考。
如有其他问题,欢迎打扰!
"""Created on Thu Jul 1 15:47:53 2021
@author: John_Huang"""
import os
import sys
import glob
import time
import tensorflow as tf
import numpy
import random
def preprocess_image(_image):
    """Turn JPEG-encoded bytes into a fixed-size, scaled image tensor.

    Args:
        _image: Scalar string tensor holding the raw contents of a JPEG file.

    Returns:
        The image decoded with 3 channels, resized to 256x256 and divided by
        255.0 so that 8-bit pixel values land in the [0, 1] range.
    """
    decoded = tf.image.decode_jpeg(_image, channels=3)
    # The target size has to match the input size of the network in use;
    # keep it in sync with the reshape in img_parse.
    resized = tf.image.resize(decoded, [256, 256])
    return resized / 255.0
def load_and_preprocess_image(_path):
    """Read the file at ``_path`` and run it through ``preprocess_image``.

    Args:
        _path: Path of the image file to load.

    Returns:
        Whatever ``preprocess_image`` returns for the file's raw contents.
    """
    raw_bytes = tf.io.read_file(_path)
    return preprocess_image(raw_bytes)
def img_parse(x):
    """Deserialize one image record back into a float32 tensor.

    Args:
        x: A serialized tensor, as produced by ``tf.io.serialize_tensor``.

    Returns:
        The float32 tensor reshaped to [256, 256, 3].
    """
    image = tf.io.parse_tensor(x, out_type=tf.float32)
    # The static shape has to match the input size of the network in use;
    # adjust it together with the resize in preprocess_image.
    return tf.reshape(image, [256, 256, 3])
def label_parse(y):
    """Deserialize one label record back into an int64 tensor.

    Args:
        y: A serialized tensor, as produced by ``tf.io.serialize_tensor``.

    Returns:
        The int64 tensor reshaped to [2].
    """
    label = tf.io.parse_tensor(y, out_type=tf.int64)
    return tf.reshape(label, [2])
def tfrecord_create(_data_path, _img_tfrecord_abspath, _label_tfrecord_abspath):
    """Build the image and label TFRecord files for the training samples.

    Every ``*.jpg`` file directly inside ``_data_path`` is used. The class of
    a file is the part of its base name before the first underscore and must
    be ``dis`` or ``undis``. The file list is shuffled once, then the
    preprocessed images and the one-hot labels are written, in that same
    order, to two separate TFRecord files.

    Args:
        _data_path: Directory containing the ``<label>_*.jpg`` sample images.
        _img_tfrecord_abspath: Output path of the image TFRecord file.
        _label_tfrecord_abspath: Output path of the label TFRecord file.

    Raises:
        ValueError: If no ``*.jpg`` file is found in ``_data_path``, or if a
            file name does not start with a known label prefix.
    """
    label2index_dict = {'dis': 0, 'undis': 1}

    all_imgs_path = glob.glob(os.path.join(_data_path, '*.jpg'))
    if not all_imgs_path:
        raise ValueError('no *.jpg files found in {!r}'.format(_data_path))
    random.shuffle(all_imgs_path)

    # Resolve every label before touching TensorFlow so that a badly named
    # file is reported by path instead of surfacing as a None inside one_hot.
    all_labels = []
    for a_img in all_imgs_path:
        prefix = os.path.basename(a_img).split('_')[0]
        if prefix not in label2index_dict:
            raise ValueError(
                'cannot derive a label from {!r}: prefix {!r} is not one of {}'.format(
                    a_img, prefix, sorted(label2index_dict)
                )
            )
        all_labels.append(label2index_dict[prefix])

    # One-hot encode the labels; depth must match the [2] reshape in label_parse.
    all_labels = tf.one_hot(indices=all_labels, depth=2, on_value=1, off_value=0, axis=-1)

    # Images: path -> preprocessed tensor -> serialized bytes -> TFRecord file.
    ds_image = tf.data.Dataset.from_tensor_slices(all_imgs_path)
    ds_image = ds_image.map(load_and_preprocess_image)
    ds_image = ds_image.map(tf.io.serialize_tensor)
    ds_image_record = tf.data.experimental.TFRecordWriter(_img_tfrecord_abspath)
    ds_image_record.write(ds_image)

    # Labels: cast to int64 because label_parse decodes them as tf.int64.
    ds_label = tf.data.Dataset.from_tensor_slices(tf.cast(all_labels, tf.int64))
    ds_label = ds_label.map(tf.io.serialize_tensor)
    ds_label_record = tf.data.experimental.TFRecordWriter(_label_tfrecord_abspath)
    ds_label_record.write(ds_label)
def tfrecord_read_decode(_img_tfrecord_abspath, _label_tfrecord_abspath, _ratio=0.2):
    """Read and decode the sample TFRecord files and split them in two.

    The image and label files are read record by record, zipped into
    ``(image, label)`` pairs, and split: the first ``int(count * _ratio)``
    pairs form the validation set, the remaining pairs form the training set.

    Args:
        _img_tfrecord_abspath: Path of the image TFRecord file.
        _label_tfrecord_abspath: Path of the label TFRecord file.
        _ratio: Fraction of the samples used for validation; must lie in
            [0, 1].

    Returns:
        A tuple ``(train_ds, train_ds_num, verify_ds, verify_ds_num)`` holding
        the training dataset, its size, the validation dataset and its size.

    Raises:
        ValueError: If ``_ratio`` is outside [0, 1], or if the two files do
            not contain the same number of records.
    """
    # A ratio outside [0, 1] would yield negative counts, which take()/skip()
    # interpret as "all elements" instead of failing.
    if not 0 <= _ratio <= 1:
        raise ValueError('_ratio must be within [0, 1], got {!r}'.format(_ratio))

    AUTOTUNE = tf.data.experimental.AUTOTUNE

    img_ds = tf.data.TFRecordDataset(_img_tfrecord_abspath)
    label_ds = tf.data.TFRecordDataset(_label_tfrecord_abspath)

    # The record counts are obtained by iterating over each file once.
    img_length = sum(1 for _ in img_ds)
    label_length = sum(1 for _ in label_ds)
    if img_length != label_length:
        # Zipping would silently truncate to the shorter file, so fail loudly.
        raise ValueError(
            '......严重错误:图片数据和标签数据tfrecord格式文件长度不一致!'
            ' (images: {}, labels: {})'.format(img_length, label_length)
        )

    img_ds = img_ds.map(img_parse, num_parallel_calls=AUTOTUNE)
    label_ds = label_ds.map(label_parse, num_parallel_calls=AUTOTUNE)
    img_label_ds = tf.data.Dataset.zip((img_ds, label_ds))

    # Split into training and validation data.
    verify_ds_num = int(img_length * _ratio)
    train_ds_num = img_length - verify_ds_num
    print('样本总数据{}个, 训练数据有{}个, 验证数据有{}个.'.format(img_length, train_ds_num, verify_ds_num))
    train_ds = img_label_ds.skip(verify_ds_num)
    verify_ds = img_label_ds.take(verify_ds_num)
    return train_ds, train_ds_num, verify_ds, verify_ds_num