【代码分析】Tensorflow OpShapeInferenceFn 详解

目录

 

背景

OpShapeInference分析

举个栗子


背景

在我之前的文章Tensorflow自定义算子实现原理 中说明了自定义OP的注册过程,其中开始的第一步通过REGISTER_OP注册算子,有一个类型为OpShapeInferenceFn的参数通过SetShapeFn注册到OpRegistrationData对象中,本文解释OpShapeInferenceFn的作用和实例

 

OpShapeInference分析

OP定义了Input和Output,但是没有定义它们的形状(shape),以矩阵乘法MatMul OP为例,REGISTER_OP的参数中并没有指定它是个[2,3] x [3,4] = [2, 4]的矩阵,还是[5,6] x [6,7] = [5, 7]的矩阵。所以对于OP需要有一个形状推断的函数,即OpShapeInferenceFn的作用。

 

typedef std::function<Status(shape_inference::InferenceContext* c)>
    OpShapeInferenceFn;

OpShapeInferenceFn定义在tensorflow/core/framework/op_def_builder.h,是一个std::function类型

 

ShapeInference 相关类

在tensorflow/core/framework/shape_inference.h中定义了做shape 推断的数据结构,如上图所示,它们关系很简单,以三维矩阵 A[3, 4, 5] 为例

代表A的Shape中rank_ 为3,dims_中有3个DimensionHandle,每个DimenionHandle指向的Dimension中value_分别为3,4,5

ShapeManager统一管理所有的Shape和Dimension的数据,而InferenceContext作为OpShapeInferenceFn的关键参数提供了各种做形状处理的API 来使用Sape和Dimension数据(Google的工程师写代码分层解耦能力刚刚的)

 

举个栗子

抽象的文字描述总是苍白的,用说人话的方式举个2维矩阵乘法MatMul的栗子🌰

REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

MatMul 算子定义在tensorflow/core/ops/math_ops.cc,其中SetShapeFn指定了形状推断函数为shape_inference::MatMulShape

Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return Status::OK();
}

 shape_inference::MatMulShape定义在tensorflow/core/framework/common_shape_fns.cc

以2维矩阵 A[3, 4] x B[4, 5] = D[3, 5] 为例分析代码如下:

  • c->input(0)获得A的ShapeHandle,c->WithRank比较第一个参数A的rank( = 2)与第二参数2相等,从而创建了一个rank=2的ShapeHandle a输出
  • c->input(1)获得B的ShapeHandle,c->WithRank比较第一个参数B的rank( = 2)与第二参数2相等,从而创建了一个rank=2的ShapeHandle b输出
  • 不考虑矩阵转置的情况,即transpose_a 和 transpose_b 都是false
    • c->Dim(a, 0)取出a的第0维(行)作为output_rows ( = 3 )
    • c->Dim(b, 1)取出b的第1维(列)作为output_cols ( = 5 )
  • 验证矩阵乘的合法性,即A的列 = B的行
    • c->Dim(a, 1)取出a的第1维(列)作为inner_a ( = 4 )
    • c->Dim(b, 0)取出b的第0维(行)作为inner_b ( = 4 )
    • 通过c->Merge比较inner_a = inner_b 即A的列 = B的行
    • c->Matrix创建一个新的ShapeHandle (rank = 2,Dimension分别为3和5)返回
  • c->set_output将Matrix创建的ShapeHandle作为output形状,至此完成了MatMul算子的输入与输出形状推断

 

 

 

 

 

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值