1 Introduction to the TFRecord Format
For large amounts of image data, TensorFlow provides a unified storage format: TFRecord. A TFRecord file stores its data in binary form and is well suited to reading large batches sequentially. Although its internal format is complex, it uses memory efficiently, is easy to copy and move, and matches the way the TensorFlow execution engine consumes data.
The data in a TFRecord file is stored as tf.train.Example protocol buffers:
message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
A tf.train.Example contains a dictionary that maps feature names to values. Each feature name is a string, and its value can be a list of byte strings (BytesList), a list of floats (FloatList), or a list of integers (Int64List).
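For example, one image and its label can be packed into an Example like this. This is only a minimal sketch; the feature names image_raw and label are the ones used by the conversion script later in this article, and the byte string is a placeholder for real image data.

import tensorflow as tf

# A minimal sketch: pack raw image bytes and an integer label into one Example.
img_raw = b'\x00\x01\x02'  # placeholder for the raw bytes of one image
example = tf.train.Example(features=tf.train.Features(feature={
    'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
}))
serialized = example.SerializeToString()  # the binary string written into a TFRecord file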
2 The Image Dataset
The image dataset used in this article comes from the Stanford Cars dataset.
Put all of the dataset's images into a data folder, and place the label file (which I have renamed to label.txt) in the same directory as the data folder, as sketched below.
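For reference, the directory layout assumed by the conversion script in the next section looks roughly like this (the image file names are only placeholders, and convert_to_tfrecords.py is the name under which that script is saved):

convert_to_tfrecords.py
label.txt
data/
    00001.jpg
    00002.jpg
    ...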
3 Converting the Images to TFRecord
from __future__ import absolute_import, division, print_function

import time
from os import walk
from os.path import join

import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize

DATA_DIR = 'data/'
IMG_HEIGHT = 227
IMG_WIDTH = 227
IMG_CHANNELS = 3
NUM_TRAIN = 7000
NUM_VALIDATION = 1144


def read_images(path):
    """Read every image under `path`, resize it, and pair it with its label."""
    filenames = next(walk(path))[2]
    num_files = len(filenames)
    images = np.zeros((num_files, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
    labels = np.zeros((num_files, ), dtype=np.uint8)
    # label.txt is assumed to contain one integer label per line,
    # in the same order as the filenames returned by walk().
    f = open('label.txt')
    lines = f.readlines()
    for i, filename in enumerate(filenames):
        img = imread(join(path, filename))
        img = imresize(img, (IMG_HEIGHT, IMG_WIDTH))
        images[i] = img
        labels[i] = int(lines[i])
    f.close()
    return images, labels


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert(images, labels, name):
    """Serialize the (image, label) pairs into `<name>.tfrecords`."""
    num = images.shape[0]
    filename = name + '.tfrecords'
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(num):
        # Store the raw uint8 pixel bytes together with the integer label.
        img_raw = images[i].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(int(labels[i])),
            'image_raw': _bytes_feature(img_raw)}))
        writer.write(example.SerializeToString())
    writer.close()
    print('Writing End')


def main(argv):
    print('reading images begin')
    start_time = time.time()
    train_images, train_labels = read_images(DATA_DIR)
    duration = time.time() - start_time
    print('reading images end , cost %d sec' % duration)

    # Split off the first NUM_VALIDATION images for validation, keep the rest for training.
    validation_images = train_images[:NUM_VALIDATION, :, :, :]
    validation_labels = train_labels[:NUM_VALIDATION]
    train_images = train_images[NUM_VALIDATION:, :, :, :]
    train_labels = train_labels[NUM_VALIDATION:]

    print('convert to tfrecords begin')
    start_time = time.time()
    convert(train_images, train_labels, 'train')
    convert(validation_images, validation_labels, 'validation')
    duration = time.time() - start_time
    print('convert to tfrecords end , cost %d sec' % duration)


if __name__ == '__main__':
    tf.app.run()
In this article, 7,000 of the images are used for training and 1,144 for validation.
4 Reading the TFRecord Files
from __future__ import absolute_import, division, print_function

import tensorflow as tf

import convert_to_tfrecords

TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'

NUM_CLASSES = 196
IMG_HEIGHT = convert_to_tfrecords.IMG_HEIGHT
IMG_WIDTH = convert_to_tfrecords.IMG_WIDTH
IMG_CHANNELS = convert_to_tfrecords.IMG_CHANNELS
IMG_PIXELS = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELS
NUM_TRAIN = convert_to_tfrecords.NUM_TRAIN
NUM_VALIDATION = convert_to_tfrecords.NUM_VALIDATION


def read_and_decode(filename_queue):
    """Read one serialized Example from the queue and decode it into an image and a label."""
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)
    })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    # Scale pixel values from [0, 255] to [-0.5, 0.5].
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label


def inputs(data_set, batch_size, num_epochs):
    """Return a shuffled batch of images and labels from the train or validation file."""
    if not num_epochs:
        num_epochs = None
    if data_set == 'train':
        file = TRAIN_FILE
    else:
        file = VALIDATION_FILE
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue)
        images, labels = tf.train.shuffle_batch([image, label],
                                                batch_size=batch_size,
                                                num_threads=64,
                                                capacity=1000 + 3 * batch_size,
                                                min_after_dequeue=1000)
    return images, labels
To read one batch of images and labels, you only need to call the inputs() function.
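A minimal usage sketch is shown below, assuming the reader code above is saved as read_tfrecords.py; that module name, the batch size, and the variable names are my own choices here.

import tensorflow as tf

import read_tfrecords  # the reader module shown above, saved under this (assumed) name

# Build the input pipeline: one shuffled batch of 32 training images and labels.
images, labels = read_tfrecords.inputs('train', batch_size=32, num_epochs=1)

with tf.Session() as sess:
    # string_input_producer with num_epochs uses local variables, so initialize both.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # Start the queue-runner threads that fill the filename and example queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        batch_images, batch_labels = sess.run([images, labels])
        print(batch_images.shape, batch_labels.shape)  # (32, 227, 227, 3) (32,)
    finally:
        coord.request_stop()
        coord.join(threads)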
5 Results
The result is a train.tfrecords of about 1 GB and a validation.tfrecords of about 168 MB. These sizes are expected: each image is stored as raw uint8 pixels taking 227 × 227 × 3 ≈ 151 KB, so 7,000 training images come to roughly 1 GB and 1,144 validation images to roughly 170 MB.