tf.data.Dataset.map()函数的用法
官方解释
此转换适用于此数据集的每个元素,并返回包含已转换元素的新数据集,其顺序与输入中显示的顺序相同
import tensorflow as tf
生成数据集
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
list(dataset.as_numpy_iterator())
[array([0, 1, 2, 3, 4, 5], dtype=int64),
array([6, 7, 8, 9], dtype=int64)]
用lambda来转换
dataset = dataset.map(lambda x: x + 10)
list(dataset.as_numpy_iterator())
[array([10, 11, 12, 13, 14, 15], dtype=int64),
array([16, 17, 18, 19], dtype=int64)]
用 函数 sum(a) 来转换
def sum(a):
return a+100
dataset = dataset.map(sum)
list(dataset.as_numpy_iterator())
[array([110, 111, 112, 113, 114, 115], dtype=int64),
array([116, 117, 118, 119], dtype=int64)]