tf.dataset 使用 python generator 加载和预处理数据,dataset.map 对数据进行额外加工

python部分 使用 generator 输入和打乱数据到 dataset,
tf.graph部分 使用 dataset.map 对输入数据进行加工
tensorflow 官方指导链接:https://tensorflow.google.cn/programmers_guide/datasets

import tensorflow as tf
# import tensorlayer as tl
import numpy as np

# 生成数据集
# x_train, y_train, x_test, y_test = tl.files.load_cifar10_dataset(path='../datasets')
x_train = np.random.uniform(0, 1, [5000, 32, 32, 3])
y_train = np.random.randint(0, 10, [5000, ])


# 定义生成器
def next_train_batch():
    # 打乱数据
    ids = np.arange(len(x_train))
    np.random.shuffle(ids)
    xs = [x_train[i] for i in ids]
    ys = [y_train[i] for i in ids]

    for i in range(len(xs)):
        yield xs[i], ys[i]


# 使用 tf 处理数据
def parse(x, y):
    x = tf.image.random_flip_left_right(x)
    return x, y


# sess = tl.utils.set_gpu_fraction(0.01)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.01)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))


# 使用 生成器 和 parse 处理数据
train_dataset = tf.data.Dataset.from_generator(next_train_batch, (tf.float32, tf.int32)).map(parse)
# 设置epoch为2,设置batch_size为505
train_dataset = train_dataset.repeat(2).batch(505)
# 初始化 dataset
train_iter = train_dataset.make_initializable_iterator()
sess.run(train_iter.initializer)
x, y = train_iter.get_next()


iter_count = 0
while True:
    try:
        # 输出 batch 形状
        a, b = sess.run([x, y])
        print(np.shape(a), np.shape(b))
        iter_count += 1
    except tf.errors.OutOfRangeError:
        break
print('iter count', iter_count)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值