TensorFlow的Dataset的padded_batch使用

 

padded_batch API如下

padded_batch(
    batch_size, padded_shapes=None, padding_values=None, drop_remainder=False
)

注意参数drop_remainder用来约束最后一个batch是不是要丢掉,当这个batch样本数少于batch_size时,比如batch_size = 3,最后一个batch只有2个样本。默认是不丢掉

padded_batch是非常见的一个操作,比如对一个变长序列,通过padding操作将每个序列补成一样的长度。

特点:

1)padded_shapes使用默认值或者设置为-1,那么每个batch padding后每个维度就是跟这个者个batch的样本各个维度最大值保持一致

2)当shape固定为特定的size时,那么每个batch的shape就是一样的。如果

 


A = tf.data.Dataset.range(1, 6, output_type=tf.int32).map(lambda x: tf.fill([x], x))
for item in A.as_numpy_iterator():
  print(item)

结果如下:

[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]

padded_batch操作:

  • padded_shapes不设置或者设置为-1

padded_shapes设置为-1跟不设置该参数的效果一样,就是按每个batch里的最大的size去进行padding

B = A.padded_batch(2, padded_shapes = [-1])
for item in B.as_numpy_iterator():
  print("*" * 20)
  print(item)

打印结果如下:

可以看出事每个batch的里的shape保持一致,长度不够的补0

********************
[[1 0]
 [2 2]]
********************
[[3 3 3 0]
 [4 4 4 4]]
********************
[[5 5 5 5 5]]
  • padded_shapes设置为固定值 

B = A.padded_batch(2, padded_shapes = [6])
for item in B.as_numpy_iterator():
  print("*" * 20)
  print(item)
  

打印结果:

可见每个batch的每个序列长度都是6,不足就补0

********************
[[1 0 0 0 0 0]
 [2 2 0 0 0 0]]
********************
[[3 3 3 0 0 0]
 [4 4 4 4 0 0]]
********************
[[5 5 5 5 5 0]]

 

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值