"""Feed and shuffle data into a tf.data.Dataset from a Python generator,
then preprocess the inputs inside the TF graph via dataset.map.

Official TensorFlow guide: https://tensorflow.google.cn/programmers_guide/datasets
"""
import tensorflow as tf
# import tensorlayer as tl
import numpy as np
# Synthetic stand-in for a real dataset (e.g. CIFAR-10 via
# tl.files.load_cifar10_dataset): 5000 random 32x32 RGB images in [0, 1)
# with integer class labels in [0, 10).
x_train = np.random.uniform(0.0, 1.0, size=(5000, 32, 32, 3))
y_train = np.random.randint(0, 10, size=(5000,))
# 定义生成器
def next_train_batch(features=None, labels=None):
    """Yield (feature, label) pairs one at a time in a freshly shuffled order.

    Generalized so the data source can be injected; with no arguments it
    falls back to the module-level x_train / y_train, preserving the
    original no-argument call used by Dataset.from_generator.

    Args:
        features: array-like of samples; defaults to the global x_train.
        labels: array-like of targets aligned with features; defaults to
            the global y_train.

    Yields:
        (feature, label) tuples, each sample exactly once per full pass.
    """
    if features is None:
        features = x_train
    if labels is None:
        labels = y_train
    # A fresh permutation per call reshuffles the data for every restart,
    # without materializing intermediate shuffled copies of the arrays.
    for i in np.random.permutation(len(features)):
        yield features[i], labels[i]
# 使用 tf 处理数据
def parse(x, y):
    """Graph-side preprocessing: randomly mirror the image left-right.

    Applied via dataset.map, so the flip runs inside the TF graph.
    The label passes through unchanged.
    """
    flipped = tf.image.random_flip_left_right(x)
    return flipped, y
# sess = tl.utils.set_gpu_fraction(0.01)
# Cap GPU memory at 1% of the card so this demo coexists with other jobs
# (equivalent to tl.utils.set_gpu_fraction(0.01)).
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.01))
sess = tf.Session(config=config)
# Input pipeline: Python generator feeds samples, parse() augments them
# inside the graph.
train_dataset = tf.data.Dataset.from_generator(
    next_train_batch, (tf.float32, tf.int32)).map(parse)
# 2 epochs, batch size 505 (the last batch may come up short).
train_dataset = train_dataset.repeat(2).batch(505)
# Initializable iterator: must be run once before get_next() is consumed.
train_iter = train_dataset.make_initializable_iterator()
sess.run(train_iter.initializer)
x, y = train_iter.get_next()
iter_count = 0
while True:
    try:
        # Pull one batch and show its shape; OutOfRangeError signals
        # the dataset is exhausted (standard TF1 consumption idiom).
        batch_x, batch_y = sess.run([x, y])
        print(np.shape(batch_x), np.shape(batch_y))
        iter_count += 1
    except tf.errors.OutOfRangeError:
        break
print('iter count', iter_count)