tf.estimator tf.data 混合不同的数据

import tensorflow as tf

dataset_1 = tf.data.Dataset.from_tensors(1).repeat(20)
dataset_2 = tf.data.Dataset.from_tensors(2).repeat(20)

dataset = tf.data.Dataset.zip((dataset_1, dataset_2))
dataset = dataset.batch(8)
dataset = dataset.map(lambda a, b: tf.concat([a, b], 0))

one_shot_iterator = dataset.make_one_shot_iterator()
sess = tf.Session()
for i in range(3):
    print(sess.run(one_shot_iterator.get_next()))

打印结果:
[1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2]
[1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2]
[1 1 1 1 2 2 2 2]

©️2020 CSDN 皮肤主题: 创作都市 设计师:CSDN官方博客 返回首页