tfrecord 的读写数据是真的麻烦,各种不方便,而且还有些坑,不太想讲这个东西,所以这里就打算写个简单的读写模板,可以作为参考。
其实写tfrecord本质只有三个类型: bytes,int64,float。所以我们要保存的数据就转成这三种类型就行了。
另外,这几种类型的数据都是一个list的形式,并且不支持多维数组,如果想要的数据是多维的,那就要转成1维,再读进来以后再转回去。
还有一点是变长数据,这种一般会padding到定长数组,然后再保存,不过这会浪费一些空间,tensorflow其实还是支持变长数据的,但是和定长数据不能一起用……也不知道这是怎么搞的,我觉得理论上应该是可以支持的吧。
写文件:
# coding: utf-8
import os, sys
import time, io
import tensorflow as tf
from PIL import Image
import numpy as np
import random
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def get_data_item():
n = random.randint(1, 10)
fvalues = []
for i in range(n):
fvalues.append(random.random())
return {
"image_path": 't.jpg',
'float_values': [1.0, 2.0, 3.0], # 同样支持numpy 格式:np.array([1.0, 2.0, 3.0])
'float_values2': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape(-1), # 不支持多维数组,reshape成1维,读成tensor以后再reshape回去
'var_values': np.array(fvalues),
}
def get_example(data_item):
# 1. 图片类型数据, 这里读取的方式有多种,最终读个二进制格式的就行, img_width, img_height根据需要可以选用
with tf.gfile.GFile(data_item['image_path'], 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
img = Image.open(encoded_jpg_io)
img_width, img_height = img.size
# 先读图片,然后通过BytesIO转一下
# img = Image.open(data_item['image_path']).convert('RGB')
# img_width, img_height = img.size
# f = io.BytesIO()
# img.save(f, 'JPEG')
# f.seek(0)
# encoded_jpg = f.read()
# 2. 文本类型数据
image_name = os.path.basename(data_item['image_path'])[:random.randi