TensorFlow 中的 CTCGreedyDecoder

ctc_greedy_decoder

tf_export 基于 api_export 类,把函数导出为公开 API(这里导出为 tf.nn.ctc_greedy_decoder)。
add_dispatch_support 是一个装饰器,为 TensorFlow Python API 添加调度(dispatch)处理包装器。

@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)

  With these logits, batch entry 0 decodes to `[0]`; its second step picks
  the blank. Batch entry 1 decodes to `[2, 0]`, with the repeated class 2
  merged. The third timestep of entry 0 lies beyond its length of 2 and is
  ignored.

  Notes:

  - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks
    as regular elements when computing the probability of a sequence.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    merge_repeated: Boolean.  Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      ie, -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where
    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:
      `decoded[0].indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.
      `decoded[0].values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.
      `decoded[0].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`
    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
  """
  # `gen_ctc_ops` is generated by `tf_gen_op_wrapper_private_py` from the
  # build rules; the C++ op name `CTCGreedyDecoder` becomes the snake_case
  # Python function `ctc_greedy_decoder`. The op returns the three components
  # of the decoded SparseTensor plus, per batch entry, the negated sum of the
  # per-timestep maximum logits.
  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs,
      sequence_length,
      merge_repeated=merge_repeated,
      blank_index=blank_index)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
                                      decoded_shape)], log_probabilities)

REGISTER_OP("CTCGreedyDecoder")

REGISTER_OP 注册算子。

// Registers the CTCGreedyDecoder op.
// The shape function does three things:
// - requires `inputs` to be rank 3 and `sequence_length` to be rank 1;
// - unifies their batch dimensions;
// - leaves the number of decoded entries unknown, because it depends on the
//   data.
REGISTER_OP("CTCGreedyDecoder")
    .Input("inputs: T")
    .Input("sequence_length: int32")
    .Attr("merge_repeated: bool = false")
    .Attr("blank_index: int = -1")
    .Output("decoded_indices: int64")
    .Output("decoded_values: int64")
    .Output("decoded_shape: int64")
    .Output("log_probability: T")
    .Attr("T: {float, double} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      // inputs: [max_time, batch_size, num_classes].
      ShapeHandle logits_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &logits_shape));
      // sequence_length: [batch_size].
      ShapeHandle lengths_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &lengths_shape));

      // Both inputs must agree on batch_size.
      DimensionHandle batch;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(logits_shape, 1),
                                  c->Dim(lengths_shape, 0), &batch));

      // Output shapes, where n (the number of decoded entries) is only known
      // at run time:
      // - decoded_indices: [n, 2]
      // - decoded_values: [n]
      // - decoded_shape: [2]
      // - log_probability: [batch_size, 1]
      DimensionHandle num_decoded = c->UnknownDim();
      c->set_output(0, c->Matrix(num_decoded, 2));
      c->set_output(1, c->Vector(num_decoded));
      c->set_output(2, c->Vector(2));
      c->set_output(3, c->Matrix(batch, 1));
      return Status::OK();
    });

REGISTER_OP

通过 OpDefBuilderWrapper 来构造 OpDef

// Implementation detail of REGISTER_OP.
// - Defines a file-static, unused InitOnStartupMarker named register_op<ctr>.
//   `ctr` is a unique id, so several registrations in one file get distinct
//   variable names.
// - Its initializer streams an OpDefBuilderWrapper(name) into
//   TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)).
// - Presumably registration then happens at startup only when that condition
//   holds. NOTE(review): confirm against TF_INIT_ON_STARTUP_IF.
// Comments are kept outside the macro body because a `//` comment on a line
// ending in `\` would swallow the next line.
#define REGISTER_OP_IMPL(ctr, name, is_system_op)                         \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr         \
      TF_ATTRIBUTE_UNUSED =                                               \
          TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)

// Public entry point: REGISTER_OP("Name").Input(...).Attr(...)...;
// The expansion ends with `OpDefBuilderWrapper(name)`. Because member access
// binds tighter than `<<`, the caller's `.Input(...)`/`.Attr(...)`/... chain is
// applied to that wrapper before it is streamed into the startup marker.
// Registered this way, the op is not a system op (is_system_op = false).
#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)

REGISTER_CPU

Name 本质上是 KernelDefBuilder 对象,在内部创建 KernelDef
KernelDefBuilder::Device 设置设备类型。
KernelDefBuilder::TypeConstraint 设置类型约束。
REGISTER_KERNEL_BUILDER 通过 OpKernelRegistrar 类完成 kernel 函数的注册。

// Registers CTCGreedyDecoderOp<T> as the DEVICE_CPU kernel of the
// CTCGreedyDecoder op, for the case where the op's "T" attr equals T.
#define REGISTER_CPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CTCGreedyDecoderOp<T>);

// One kernel per type permitted by the op's "T: {float, double}" attr.
REGISTER_CPU(float);
REGISTER_CPU(double);

CTCGreedyDecoderOp

OpKernelConstruction::GetAttr 获取属性值。

template <typename T>
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("blank_index", &blank_index_));
  }

CTCGreedyDecoderOp::Compute

CTCGreedyDecoderOp::Compute
CTCDecodeHelper::ValidateInputsGenerateOutputs
CTCDecodeHelper::StoreAllDecodedSequences

输入为 Tensor 指针。
OpOutputList 是由单个命名输出组成的输出张量列表。
CTCDecodeHelper::ValidateInputsGenerateOutputs 验证输入并生成输出张量。
TensorShape 表示一个张量的形状。

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));
    const TensorShape& inputs_shape = inputs->shape();

TTypes::UnalignedConstMatrixTensorMap<Tensor<data_type, rank>> 类型,用于在代码的另一部分分配和拥有的内存之上创建张量。它允许将任何分配的内存视为张量。此类的实例不拥有存储数据的内存。TensorMap 不可调整大小,因为它不拥有存储其数据的内存。
TensorShapeBase::dim_size 返回指定维度的大小。

    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);
    const int64_t num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<const int>(num_classes_raw);

Tensor::tensor 返回嵌套定义的 TTypes:Tensor 对象。
把每个时间片上的数据构造为 TTypes::UnalignedConstMatrix,追加到input_list_t数组。
Tensor::vec 返回一个一维 TTypes::Vec
Tensor::matrix 返回一个二维 TTypes::Matrix
Tensor::setZerolog_prob_t清零。

    auto inputs_t = inputs->tensor<T, 3>();

    input_list_t.reserve(max_time);
    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<T>();

    log_prob_t.setZero();

    int blank_index =
        (blank_index_ < 0) ? num_classes + blank_index_ : blank_index_;
    OP_REQUIRES(ctx, FastBoundsCheck(blank_index, num_classes),
                errors::InvalidArgument("blank_index expected to be between ",
                                        -num_classes, " and ", num_classes - 1,
                                        " but was ", blank_index_));

decode函数在循环中处理单 batch 的数据。
sequences存储每个批次的每个路径的解码值,所以是三层嵌套。GreedyDecoder 只生成一条路径。
seq_len_t数组中获取每个序列的长度。
input_list_t[t]的形状为[batch_size, num_classes]RowMax 找到当前批次的最大概率值及其对应索引。
log_prob_t累积其负数和。
如果不是空白索引且满足重复过滤条件,则添加到路径中。

    // Perform best path decoding
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    auto decode = [&](const int64_t begin, const int64_t end) {
      for (int b = begin; b < end; ++b) {
        sequences[b].resize(1);
        auto &sequence = sequences[b][0];
        int prev_indices = -1;
        for (int t = 0; t < seq_len_t(b); ++t) {
          int max_class_indices;
          OP_REQUIRES(ctx, input_list_t[t].dimension(1) > 0,
                      errors::InvalidArgument("Invalid input dimensions."));
          log_prob_t(b, 0) +=
              -RowMax<T>(input_list_t[t], b, &max_class_indices);
          if (max_class_indices != blank_index &&
              !(merge_repeated_ && max_class_indices == prev_indices)) {
            sequence.push_back(max_class_indices);
          }
          prev_indices = max_class_indices;
        }
      }
    };

DeviceBase::tensorflow_cpu_worker_threads 返回嵌套定义的结构体 DeviceBase::CpuWorkerThreads,其中存储了线程数和线程池指针。
Shard 函数。
CTCDecodeHelper::StoreAllDecodedSequencessequences转换为3个 OpOutputList

    const int64_t kCostPerUnit = 50 * max_time * num_classes;
    const int64_t total = batch_size;
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *ctx->device()->tensorflow_cpu_worker_threads();
    Shard(worker_threads.num_threads, worker_threads.workers, total,
          kCostPerUnit, decode);

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

CTCDecodeHelper 用于转换并保存结果。
TF_DISALLOW_COPY_AND_ASSIGN 禁止拷贝构造和赋值构造。

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;
  int blank_index_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

CTCDecodeHelper

 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

CTCDecodeHelper::ValidateInputsGenerateOutputs

CTCDecodeHelper::ValidateInputsGenerateOutputs
OpKernelContext::input
OpKernelContext::allocate_output
OpKernelContext::output_list

OpKernelContext::input 根据名字得到对应的输入张量。

  // Fetches the "inputs" and "sequence_length" tensors, validates them, and
  // prepares the decoder outputs.
  //
  // On success:
  // - *inputs and *seq_len point at the op's input tensors.
  // - *log_prob is a newly allocated [batch_size, top_paths_] tensor. This
  //   function does not write to it, so the caller must initialize it.
  // - *decoded_indices, *decoded_values and *decoded_shape are bound to the
  //   output lists of the same names. Their tensors are allocated later by
  //   StoreAllDecodedSequences.
  //
  // Returns InvalidArgument when:
  // - `inputs` is not a non-empty rank-3 tensor, or
  // - `sequence_length` is not a vector.
  // Returns FailedPrecondition when:
  // - len(sequence_length) != batch_size, or
  // - any sequence_length(b) > max_time.
  // Any error from fetching inputs or allocating outputs is passed through
  // unchanged.
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    // Look up the input tensors by name.
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    // Check shapes and sizes; errors::* build the returned error Status.
    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }
    if (inputs_shape.num_elements() == 0) {
      return errors::InvalidArgument("inputs must not be empty");
    }

    // inputs is [max_time, batch_size, num_classes].
    const int64_t max_time = inputs_shape.dim_size(0);
    const int64_t batch_size = inputs_shape.dim_size(1);

    // Already implied by the non-empty check above; kept as a defensive guard.
    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (batch_size != (*seq_len)->dim_size(0)) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size.  ",
          "len(sequence_length):  ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    // The index is int64_t to match batch_size; an int index could overflow
    // for very large batches.
    for (int64_t b = 0; b < batch_size; ++b) {
      // The message states the precondition that was violated.
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") <= ", max_time);
      }
    }

    // Allocate log_probability ([batch_size, top_paths_]).
    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    // Bind each named output to its OpOutputList.
    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

CTCDecodeHelper::StoreAllDecodedSequences

sequences 以 3 个输出变量 decoded_indices、decoded_values 和 decoded_shape 表示。
top_paths_ 为最优路径的数量。
对每条最优路径,先统计它在整个 batch 中的条目总数(num_entries)。

  // Packs decoded paths into the decoder's sparse outputs.
  //
  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  // Every batch entry must hold exactly top_paths_ paths (CHECK-enforced).
  //
  // For each path p, with n = total length of path p over the whole batch,
  // slot p of each output list is allocated as:
  // - decoded_indices[p]: [n, 2] int64, rows of (batch, time);
  // - decoded_values[p]: [n] int64, the decoded classes;
  // - decoded_shape[p]: [2] int64, {batch_size, longest decoded length}.
  // Returns the first allocation error, if any.
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    const int64_t batch_size = sequences.size();

    // Total number of entries of each path, summed over the batch. This sizes
    // the per-path output tensors.
    std::vector<int64_t> num_entries(top_paths_, 0);
    for (const auto& batch_s : sequences) {
      // Compare as size_t to avoid a signed/unsigned mix.
      CHECK_EQ(batch_s.size(), static_cast<size_t>(top_paths_));
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    // For each path, allocate outputs sized by num_entries[p], then fill in
    // the indices, the label values and the 2-D dense shape.
    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64_t p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64_t>();
      auto values_t = p_values->vec<int64_t>();
      auto shape_t = p_shape->vec<int64_t>();

      int64_t max_decoded = 0;
      // Write position in indices_t/values_t. Batch entries are stored back
      // to back, in batch order.
      int64_t offset = 0;

      for (int64_t b = 0; b < batch_size; ++b) {
        const auto& p_batch = sequences[b][p];
        const int64_t num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        if (num_decoded > 0) {
          DCHECK_NE(values_t.data(), nullptr)
              << "values_t should not be nullptr: p_num=" << p_num
              << " num_decoded=" << num_decoded;
          DCHECK_LT(offset, values_t.size())
              << "offset should be smaller than values_t.size()";
          // Widens the int class ids into the int64 values tensor.
          std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        }
        // One (batch, time) index row per decoded value.
        for (int64_t t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }
 private:
  // Number of decoded paths per batch entry. The constructor sets it to 1;
  // SetTopPaths can change it.
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);

OpKernelContext::input

根据名字获取输入索引,然后从 params_->inputs(元素类型为 TensorValue 的 tensorflow::gtl::InlinedVector)中取出对应元素的 tensor,赋给 *tensor。

  // Resolve the input name to its positional index.
  int index;
  TF_RETURN_IF_ERROR(get_input_index(name, &index));
  // This accessor only returns non-ref inputs; a ref input is an error.
  if (input_is_ref(index)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  // Hand back the tensor stored for that input.
  *tensor = (*params_->inputs)[index].tensor;
  return Status::OK();

RowMax

UnalignedConstMatrix
CHECK_LT
找到矩阵中指定行的最大值,返回最大值并记录其列索引到c中。

// Returns the largest entry in row `r` of `m` and stores its column in `*c`.
// Ties go to the lowest column: a later column replaces the current best only
// when it is strictly greater. `m` must have at least one column
// (CHECK-enforced).
template <typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
                int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  int best_col = 0;
  for (int col = 1; col < m.dimension(1); ++col) {
    if (m(r, best_col) < m(r, col)) best_col = col;
  }
  *c = best_col;
  return m(r, best_col);
}

Shard

Shard
ThreadPool::ParallelFor

GetPerThreadMaxParallelism 返回全局变量 per_thread_max_parallelism
如果小于等于1,则直接执行work任务函数。

  // A negative amount of work is a caller bug; zero work is a no-op.
  CHECK_GE(total, 0);
  if (total == 0) {
    return;
  }
  // Cap the requested parallelism with GetPerThreadMaxParallelism().
  max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism());
  if (max_parallelism <= 1) {
    // Just inline the whole work since we only have 1 thread (core).
    work(0, total);
    return;
  }

ThreadPool::ParallelFor 线程并行处理。

  // If the caller allows at least as much parallelism as the pool has
  // threads, delegate the whole range to the pool's ParallelFor.
  if (max_parallelism >= workers->NumThreads()) {
    workers->ParallelFor(total, cost_per_unit, work);
    return;
  }

Sharder::Do 方式已经废弃了。

  // Otherwise use the older Sharder::Do path. It receives max_parallelism
  // and a scheduler that posts each closure onto `workers`.
  Sharder::Do(
      total, cost_per_unit, work,
      [&workers](Sharder::Closure c) { workers->Schedule(c); },
      max_parallelism);

ThreadPool::ParallelFor

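调用 Eigen::ThreadPoolDevice::parallelFor 函数来处理(threadpool_device_ 是 Eigen 的线程池设备,代价模型用 Eigen::TensorOpCost 描述;它不是下文的 tensorflow::ThreadPoolDevice)。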

  // `total` must be non-negative and representable as Eigen::Index. The
  // CHECK_EQ fails if the round trip through that type changes the value.
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);
  // Forward to Eigen's parallelFor. The cost model is
  // TensorOpCost(0, 0, cost_per_unit), assuming Eigen's argument order
  // (bytes_loaded, bytes_stored, compute_cycles). NOTE(review): confirm.
  // The adapter lambda converts Eigen's [first, last) index range back into
  // a call to `fn`.
  threadpool_device_->parallelFor(
      total, Eigen::TensorOpCost(0, 0, cost_per_unit),
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });

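ThreadPoolDevice(tensorflow::ThreadPoolDevice,即 TensorFlow 的 CPU 设备类;与上面 ParallelFor 中使用的 Eigen::ThreadPoolDevice 不是同一个类)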

// CPU device implementation.
//
// A LocalDevice subclass; only the declarations are shown here. The member
// functions other than Sync() and GetScopedAllocatorMgr() are defined out of
// line.
class ThreadPoolDevice : public LocalDevice {
 public:
  // NOTE(review): presumably `allocator` is stored in allocator_, which is
  // documented as not owned — confirm in the .cc file.
  ThreadPoolDevice(const SessionOptions& options, const string& name,
                   Bytes memory_limit, const DeviceLocality& locality,
                   Allocator* allocator);
  ~ThreadPoolDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override;
  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64_t step_id) override;
  // Non-owning view of the owned scoped_allocator_mgr_.
  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return scoped_allocator_mgr_.get();
  }
  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override;
  void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override;

  // Nothing to wait for: always returns OK immediately.
  Status Sync() override { return Status::OK(); }

  // Runs a synchronous / asynchronous kernel on this device.
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
  // Helpers that log a kernel's inputs / outputs.
  // NOTE(review): when they are invoked is decided in the .cc file.
  void LogInputs(OpKernel* op_kernel, OpKernelContext* context);
  void LogOutputs(OpKernel* op_kernel, OpKernelContext* context);

  Allocator* allocator_;  // Not owned
  std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
  NodeFileWriter* node_file_writer_ = nullptr;  // not owned
};

参考资料:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值