学习完Mnist手写体识别后,想用Cifar数据集制作成方便深度神经网络训练的二进制数据集文件格式。Cifar数据集包含
airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck10种分类,共6万张图片,其中训练集5万张,测试集1万张。一般情况下,训练的数据都是图片,如果直接读取图片进行训练,程序的执行效率不高。采用tfrecords可以把图片和对应的标签放在一个二进制文件中。接下来将介绍如何生成便于训练的Cifar数据集。具体思路是:从Cifar这个文件夹中遍历10个种类 的文件夹,然后根据每个文件夹的路径遍历所有的文件。
import tensorflow as tf
from PIL import Image
import os
image_train_path='./train/'
label_train_path='./test_label.txt'
tfRecord_train='./Cifar_train.tfrecords'
image_test_path='./test/'
label_test_path='./test_label.txt'
tfRecord_test='./Cifar_test.tfrecords'
data_path='./data'
def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName)
num_pic = 0
label_cnt = 0
获取各个子文件夹的文件名
Data_dir = os.listdir(image_path)
对获取的文件名进行排序,方便制作相应的标签
Data_dir = sorted(Data_dir)
for child_dir in Data_dir:
获取每个子文件夹的路径
img_path = image_path + child_dir +'/'
pic_name = os.listdir(img_path)
标签初始化
labels = [0]*10
labels[label_cnt] =1
label_cnt +=1
循环遍历每个子文件夹里面的图片
for pic_path in pic_name:
获取每个文件的路径,然后打开图片
flie_path = img_path + pic_path
img = Image.open(flie_path)
将图片信息转换为二进制信息
img_raw = img.tobytes()
把图片和对应的标签存入一个二进制文件
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString())
num_pic += 1
writer.close()
print("write tfrecord successful")
def generate_tfRecord():
isExists = os.path.exists(data_path)
if not isExists:
os.makedirs(data_path)
print 'the directory was created successfully'
else:
print 'dirctory already exists'
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
以上代码是写入数据函数,接下来是在训练的时候,如何读取训练数据。接着上面的程序继续
def read_tfRecord(tfRecord_path):
filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img,[32, 32,3])
img = tf.cast(img, tf.float32) * (1. / 255)
label = tf.cast(features['label'], tf.float32)
return img, label
def get_tfrecord(num, isTrain=True):
if isTrain:
tfRecord_path = tfRecord_train
else:
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size =num,
num_threads =2,
capacity = 2000,
min_after_dequeue = 1000)
return img_batch, label_batch
def main():
generate_tfRecord()
if __name__ == '__main__':
main()
运行这个上述代码即可得到训练集文件和测试集文件。