A generator works like an engine: it can keep churning out data in batches, endlessly.
Let's look at the classic Keras example:
def generate_arrays_from_file(path):
    while True:
        with open(path) as f:
            for line in f:
                # create numpy arrays of input data
                # and labels, from each line in the file
                x1, x2, y = process_line(line)
                yield ({'input_1': x1, 'input_2': x2}, {'output': y})

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                    steps_per_epoch=10000, epochs=10)
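To see the pattern run end to end, here is a minimal, self-contained sketch in the classic Keras / TF 1.x style used throughout this post. The process_line parser, the toy data file my_file.txt, and the tiny two-input model are all assumptions added for illustration, not part of the original example (in recent TF 2.x, model.fit accepts the generator directly and fit_generator is deprecated):

import numpy as np
import tensorflow as tf

# Hypothetical toy data file: each line is "x1,x2,y"
with open('my_file.txt', 'w') as f:
    for i in range(100):
        f.write('%d,%d,%d\n' % (i, 2 * i, 3 * i))

def process_line(line):
    # Hypothetical parser: split an "x1,x2,y" line into three batch-of-1 arrays
    x1, x2, y = (float(v) for v in line.strip().split(','))
    return np.array([[x1]]), np.array([[x2]]), np.array([[y]])

def generate_arrays_from_file(path):
    while True:  # loop forever so training can run for any number of epochs
        with open(path) as f:
            for line in f:
                x1, x2, y = process_line(line)
                yield ({'input_1': x1, 'input_2': x2}, {'output': y})

# Toy two-input model whose input/output names match the dict keys above
in1 = tf.keras.layers.Input(shape=(1,), name='input_1')
in2 = tf.keras.layers.Input(shape=(1,), name='input_2')
out = tf.keras.layers.Dense(1, name='output')(
    tf.keras.layers.concatenate([in1, in2]))
model = tf.keras.models.Model([in1, in2], out)
model.compile(optimizer='adam', loss='mse')

model.fit_generator(generate_arrays_from_file('my_file.txt'),
                    steps_per_epoch=100, epochs=2)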
a = [4, 16, 22, 23, 26, 27, 28, 36, 44, 51, 56, 64,
     71, 74, 75, 80, 82, 88, 84, 100, 104, 118, 123, 129, 130]

def data_gen():
    while True:
        for i in a:
            x = i
            yield x

y = data_gen()
x = next(y)
print(x)
x = next(y)
print(x)
x = next(y)
print(x)
Output:
4
16
22
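Because of the while True loop, data_gen() never raises StopIteration: after the last element of a it simply starts over from the beginning of the list. Continuing with the a and data_gen defined above, one way to see the wrap-around is to pull more values than the 25 elements in a, for example with itertools.islice:

from itertools import islice

# take 27 values from the (infinite) generator: after 130, the 25th and last
# element of a, the stream wraps around and yields 4, 16 again
print(list(islice(data_gen(), 27)))
# [4, 16, 22, ..., 123, 129, 130, 4, 16]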
Using tf.data as a generator
Test set:
import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices((test_x, test_y))
ds = ds.batch(1)
x, y = ds.make_one_shot_iterator().get_next()
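Because the test pipeline has batch(1) but no repeat(), the one-shot iterator yields every test sample exactly once and then raises tf.errors.OutOfRangeError; catching that error is the standard TF 1.x way to end the evaluation loop. A minimal sketch, with hypothetical test_x / test_y arrays standing in for real test data:

import numpy as np
import tensorflow as tf  # TF 1.x style, matching the code above

# hypothetical stand-in test data
test_x = np.arange(10, dtype=np.float32)
test_y = np.arange(10, dtype=np.float32) * 2

ds = tf.data.Dataset.from_tensor_slices((test_x, test_y))
ds = ds.batch(1)
x, y = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            # run both tensors in one call so they come from the same batch
            bx, by = sess.run([x, y])
            print(bx, by)
        except tf.errors.OutOfRangeError:
            break  # dataset exhausted: every test sample was seen exactly once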
Training set:
ds = tf.data.Dataset.from_tensor_slices((train_x, train_y))
ds = ds.repeat().shuffle(1000).batch(BATCH_SIZE)
x, y = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    # run both tensors in a single call so a and b come from the same batch;
    # two separate sess.run() calls would advance the iterator twice
    a, b = sess.run([x, y])
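The two halves of this post connect naturally: a plain Python generator like data_gen above can itself feed a tf.data pipeline via tf.data.Dataset.from_generator, after which the usual shuffle/batch chain applies. A minimal TF 1.x sketch (the generator is reproduced so the snippet runs on its own; the batch size of 5 and the shuffle buffer of 1000 are arbitrary choices):

import tensorflow as tf  # TF 1.x style, matching the code above

def data_gen():
    # same infinite generator pattern as above, repeated here for completeness
    a = [4, 16, 22, 23, 26, 27, 28, 36, 44, 51, 56, 64,
         71, 74, 75, 80, 82, 88, 84, 100, 104, 118, 123, 129, 130]
    while True:
        for i in a:
            yield i

ds = tf.data.Dataset.from_generator(data_gen, output_types=tf.int32)
ds = ds.shuffle(1000).batch(5)   # no repeat() needed: the generator loops forever
x = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(x))       # three batches of 5 values each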
For more short articles on mathematical principles, follow the WeChat public account: 未名方略