12_TFRecord数据打包
4.TensorFlow中的数据读取方式
4.1直接将数据加载到内存中
在处理小规模数据,比如mnist的时候可以将数据直接存入内存。但是对于较大的数据集,这是很麻烦的。因此好的方法就是将数据存放在磁盘上,在需要的时候再加载进内存,比如训练时每次加载一个batchsize的数据。
但是这样的方法也是存在问题的,如下图,因为计算和加载数据是不同步的,若计算设备完成 计算时数据还未完成加载,这个时候就会造成计算训练处的延迟,导致资源浪费。
因此更好地办法是预取数据并且使用独立的线程进行加载和训练。
4.2TFRecord
TensorFlow提供了TFRecord:高效的TensorFlow文件格式来统一存储数据,TFRecord就是一个简单的包含序列化输入数据的二进制文件,序列化是基于协议缓冲区的,这样可以快速地实现数据的复制,移动,读取和存储。
如下图,TensorFlow数据读取机制(数据->文件名队列->内存队列->计算设备读取数据):
首先要理解EPOCH的概念,一个epoch就是把数据集里的所有图片都学习了一遍,比如一个数据集有100张图片,训练时batchsize为10,那么一个epoch就是把这100张图片都计算了一遍,而一个epoch就包含了10(总量100除以每批的数据量10)个batch的数据。文件名队列就构成了数据的队列,设定shuffle参数可以让数据乱序。
比如原来有A,B,C三张图片,若设定训练3个epoch,shuffle=false,则最终文件名队列中的结果就是[A,B,C,A,B,C,A,B,C]。
比如原来有A,B,C三张图片,若设定训练3个epoch,shuffle=true,则最终文件名队列中的结果可能是[A,C,B,B,C,A,C,A,B]。每个epoch内的ABC顺序不一定。
4.3数据操作实践cifar_10数据集
4.3.1下载并解析cifar10
首先下载cifar10数据集,由于官网下载可能很慢,我将数据集放在了百度网盘里(提取码o5yu):[Download cifar10]。
我的目录结构:
.
├── cifar10_bin_to_jpg.py
├── cifar10_jpg_to_tfrecord.py
├── test_data
│ └── data_batch_6
└── train_data
├── batches.meta
├── cifar10_jpgs
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
└── data_batch_5
首先解析下载下来的二进制文件,将里面的数据解析并存放到每个数据到train_data/cifar10_jpgs目录中对应label的目录下。
cifar10_bin_to_jpg.py:
import os
import pickle
import numpy as np
import cv2
classes= ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
def unpickle(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
data = []
labels = []
for i in range(1,6):
file_name = "train_data/data_batch_"+str(i)
result = unpickle(file_name)
data += list(result[b"data"])
labels += list(result[b"labels"])
print(file_name+" loaded.")
imgs = np.reshape(data, [-1, 3, 32, 32])
for i in range(imgs.shape[0]):
im_data = imgs[i, ...]
im_data = np.transpose(im_data, [1, 2, 0])
im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
f = "{}/{}".format("train_data/cifar10_jpgs", classes[labels[i]])
if not os.path.exists(f):
os.mkdir(f)
cv2.imwrite("{}/{}.jpg".format(f, str(i)), im_data)
print("All Done.")
完成后可以在train_data/cifar10_jpgs目录下看到十个类别的数据里面的图片数据。
4.3.将解析后的cifar10打包成tfrecord文件。
涉及到的tf api:
数据读取:tf.train.string_input_producer
数据解析:tf.TFRecordReader tf.parse_single_example
数据写入:tf.python_io.TFRecordWriter
数据写入:
writer = tf.python_io.TFRecordWriter()
example = tf.train.example() Feature{图像数据,图像label等。。}
//序列化后写入
writer.write(example.SerializeToString())
writer.close()
import tensorflow as tf
import glob
import cv2
import numpy as np
classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
idx = 0
im_data = []
im_labels = []
for class_ in classes:
path = "train_data/cifar10_jpgs/" + class_
im_list = glob.glob(path + "/*")
im_label = [idx for i in range(im_list.__len__())]
idx += 1
im_data += im_list
im_labels += im_label
#实例化一个TFRecordWriter对象
tfrecord_file = "train_data/train_data.tfrecord"
writer = tf.python_io.TFRecordWriter(tfrecord_file)
index = [i for i in range(im_data.__len__())]
np.random.shuffle(index)
print("strat make")
print("total:",im_data.__len__())
for i in range(im_data.__len__()):
if i%5000 is 0 :
print(i)
im_d = im_data[index[i]]
im_l = im_labels[index[i]]
data = cv2.imread(im_d)
#tf.train.Example是一个存放数据的结构,一个example包含一个Features对象
ex = tf.train.Example(
features = tf.train.Features(
feature = {
"image":tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[data.tobytes()])),
"label": tf.train.Feature(
int64_list=tf.train.Int64List(
value=[im_l])),
}
)
)
writer.write(ex.SerializeToString())
writer.close()
print("make tfrecord_file done.")