数据集格式,将所有测试或者训练数据集各自保存在一个文件夹下:如下图:训练数据集
制作.tfrecords格式数据集的代码:
import os
import tensorflow as tf
from PIL import Image #注意Image,后面会用到
# Imagenet图片都保存在/data目录下,里面有1000个子目录,获取这些子目录的名字
classes = os.listdir('G:\\train\\')
cwd1='G:\\test\\'
cwd2='G:\\train\\'
print(classes)
writer1 = tf.python_io.TFRecordWriter("test.tfrecords") # 要生成的文件
writer2 = tf.python_io.TFRecordWriter("train.tfrecords") # 要生成的文件
for index in range(10):
print(index)
name=classes[index]
print(name)
class_path1 = cwd1 + name + '/' #训练数据集的路径
class_path2 = cwd2 + name + '/' #测试数据集的路径
for img_name in os.listdir(class_path1):
img_path = class_path1 + img_name # 每一个图片的地址
img = Image.open(img_path)
img = img.resize((32, 32))
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example对象对label和image数据进行封装
writer1.write(example.SerializeToString()) # 序列化为字符串
for img_name2 in os.listdir(class_path2):
img_path2 = class_path2 + img_name2 # 每一个图片的地址
img2 = Image.open(img_path2)
img2 = img2.resize((64, 64))
img_raw2 = img2.tobytes() # 将图片转化为二进制格式
example2 = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw2]))
})) # example对象对label和image数据进行封装
writer2.write(example2.SerializeToString()) # 序列化为字符串
writer1.close()
writer2.close()
读取数据集代码:
def read_and_decode(filename): # 读入dog_train.tfrecords
filename_queue = tf.train.string_input_producer([filename]) # 生成一个queue队列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
}) # 将image数据和label取出来
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [32 * 32 * 3]) # reshape为128*128的3通道图片
img = tf.cast(img, tf.float32) * (1. / 255)
label = tf.cast(features['label'], tf.int32) # 在流中抛出label张量
return img, label
#打乱数据
def createBatch(filename, batchsize):
images, labels = read_and_decode(filename)
min_after_dequeue = batchsize
capacity = 3 * batchsize
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size=batchsize,
capacity=capacity,
min_after_dequeue=min_after_dequeue
)
label_batch = tf.one_hot(label_batch, depth=10)
return image_batch, label_batch
使用细节
多个tfrecord文件
data_dir = '/home/sanyuan/dataset_animal/dataset_tfrecords/'
filenames = [os.path.join(data_dir,'train%d.tfrecords' % ii) for ii in range(1)] #如果有多个文件,直接更改这里即可
filename_queue = tf.train.string_input_producer(filenames)
image = read_and_decode(filename_queue)
tfrecord本质是创建一个文件队列,创建一个内存队列,内存队列不需要创建,文件队列则需要通过tfrecords文件名来创建。
shuffle参数判断远问是否需要打乱,num_epochs迭代的代数100
filename_queue = tf.train.string_input_producer([filename],shuffle=True,num_epochs=100) # 生成一个queue队列