tensorflow padded_batch的注意事项

tensorflow版本:1.13.1

 

1、在padded_batch中,若函数

出现错误:TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'int'>.

     案例如下,在以下案例中, 出现上述错误的原因是:在paddd_batch中的参数:padding_values, 如果不设置此参数,则没有问题,即此行代码更改为:

dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))

模型的padding_values为0, 对应着padded_shape中两个tensor的padding_values都为0,此时若是出现另外一种情况:某种情况下想要填充的值不为0,而是为-1,那么可以设置成如下代码即可:

   

dataset = dataset.padded_batch(batch_size, padded_shapes=([1501], [None]), padding_values=(tf.constant(-1, tf.int64), tf.constant(0, tf.int64)))

每个样本example返回两个tensor, 每个tensor的shape不确定(从padded_shape可以看出来是两个tensor,并且形状是变长,不确定的),那么padding_values也需要保持每个tensor的padding_values是两个(example的返回两个tensor对应一个padding_values中的tensor)

 

2、TypeError: Batching of padded sparse tensors is not currently supported

     若是出现上述的错误,那么是由于在解析单个example时,定义的变长tensor特征,即VarLenFeature, 那么只需要针对VarLenFeature的tensor转成dense tensor 即可:

example["text_ids"] = tf.sparse_tensor_to_dense(example["text_ids"])
example["label_ids"] = tf.sparse_tensor_to_dense(example["label_ids"])

 

    def get_train_tfrecord(tfrecord_path=None, num_epochs=1, batch_size=16, shuffle=True):

        dataset = tf.data.TFRecordDataset(tfrecord_path)


        # feature_dict = {"text_ids": tf.FixedLenFeature([self.max_len], dtype=tf.int64), "label_ids": tf.FixedLenFeature([self.max_len], dtype=tf.int64)}

        def parse_one_example(example):
            feature_dict = {"text_ids": tf.VarLenFeature(tf.int64),
                            "label_ids": tf.VarLenFeature(tf.int64)}
            example = tf.parse_single_example(example, feature_dict)
            example["text_ids"] = tf.sparse_tensor_to_dense(example["text_ids"])
            example["label_ids"] = tf.sparse_tensor_to_dense(example["label_ids"])
            # text_ids = tf.to_int32(example["text_ids"])
            # label_ids = tf.to_int32(example["label_ids"])

            text_ids = example["text_ids"]
            print("text_ids:", text_ids)
            label_ids = example["label_ids"]

            return text_ids, label_ids

        dataset = dataset.map(parse_one_example)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
        # dataset = dataset.batch(batch_size=batch_size)
        # dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]))
        dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None]), padding_values=(tf.constant(-1, tf.int64), tf.constant(0, tf.int64)))
        dataset = dataset.repeat(num_epochs)
        iterator = dataset.make_one_shot_iterator()
        batch_data = iterator.get_next()
        return batch_data  # batch_text_ids = batch_data[0], batch_label_ids = batch_data[1]

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值