padded_batch API如下
padded_batch(
batch_size, padded_shapes=None, padding_values=None, drop_remainder=False
)
注意参数drop_remainder用来约束最后一个batch是不是要丢掉,当这个batch样本数少于batch_size时,比如batch_size = 3,最后一个batch只有2个样本。默认是不丢掉
padded_batch是非常见的一个操作,比如对一个变长序列,通过padding操作将每个序列补成一样的长度。
特点:
1)padded_shapes使用默认值或者设置为-1,那么每个batch padding后每个维度就是跟这个者个batch的样本各个维度最大值保持一致
2)当shape固定为特定的size时,那么每个batch的shape就是一样的。如果
A = tf.data.Dataset.range(1, 6, output_type=tf.int32).map(lambda x: tf.fill([x], x))
for item in A.as_numpy_iterator():
print(item)
结果如下:
[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]
padded_batch操作:
-
padded_shapes不设置或者设置为-1
padded_shapes设置为-1跟不设置该参数的效果一样,就是按每个batch里的最大的size去进行padding
B = A.padded_batch(2, padded_shapes = [-1])
for item in B.as_numpy_iterator():
print("*" * 20)
print(item)
打印结果如下:
可以看出事每个batch的里的shape保持一致,长度不够的补0
********************
[[1 0]
[2 2]]
********************
[[3 3 3 0]
[4 4 4 4]]
********************
[[5 5 5 5 5]]
-
padded_shapes设置为固定值
B = A.padded_batch(2, padded_shapes = [6])
for item in B.as_numpy_iterator():
print("*" * 20)
print(item)
打印结果:
可见每个batch的每个序列长度都是6,不足就补0
********************
[[1 0 0 0 0 0]
[2 2 0 0 0 0]]
********************
[[3 3 3 0 0 0]
[4 4 4 4 0 0]]
********************
[[5 5 5 5 5 0]]