【xla】二、【构图阶段】MarkForCompilationPassImpl

一、主要函数:

  Initialize()
  1433   TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); // 尽量多的收缩节点到cluster内
  1434   TF_RETURN_IF_ERROR(CreateClusters()); // 更改图中节点的属性(更改device或打xla标记)
  1435   TF_RETURN_IF_ERROR(DumpDebugInfo()); // LOG 相关

二、Initialize

FindCompilationCandidates 遍历图挑出可优化的node(本文章第七章)
CreateCycleDetectionGraph 创建一个更加适合修改图的数据结构:GraphCycles,该数据结构可以更加方便的插入边、删除变、插入点等操作。
BuildInitialClusterSet 为每个Op都创建一个新的cluster。然后简历nodeId到cluster的映射:cluster_for_node_, 之后的所有操作都是基于cluster和graph cycles操作。

三、RunEdgeContractionLoop:

phase0: IsShapeConsumerOp 尽量将x 跟 x的shape 放在一个cluster内,X -> Y -> Size -> Z 尽量将Y跟Size cluster到一起,因为Y的输出有可能是个非常大的tensor,大tensor的传递比较困难。
	这里的shape指:Shape,Rank,Size
phase1: IsSinkLike IsScalarIntegerResourceOperation 对NoOp和i++这类op做特殊处理。
phase2: 收缩任何可以收缩的边,来得到一个最大的cluster。
check 再次尝试收缩,如果跟phase2结果相等即正确。

四、CreateClusters

DumpGraphToFile: 通过graph保存优化前的图。
遍历CompilationCandidates(由Initialize::FindCompilationCandidates生成)中的每个node。
	1. 为每个node添加属性:xla_name("cluster_x")和kXlaAlreadyClustered。

五、Cluster合并逻辑:

  1. cluster(算子)合并失败的条件:

在MarkForCompilationPassImpl::TryToContractEdge中:

  1. 两个cluster的生命周期不一样。

  2. 两个cluster的设备不一致。

  3. 两个cluster的scope不一致(不能合并pipeline类似的算子)。

  4. 不能超过cluster的最大算子个数(max默认:std::numeric_limits<int32>::max(), min默认:4)。

  5. 不能跨设备依赖。

  6. 可能破坏资源变量语义(暂时不会出现)。

六、候选算子(只有在init阶段成为候选算子,才有机会进入整个jit的优化过程)过滤条件:

  1. 根据算子ID对算子进行排序。

  2. 通过tf_xla_ops_to_cluster决定白名单内容。

    1. Fusible: 全部白名单table中的算子都放入白名单list中。

    2. 白名单table中包含了用户传入的集合名称:把s对应的算子都加入白名单list中。

    3. 用户直接提供算子名称:把算子加入到白名单list中。

  3. 拿到通过XLA注册器注册的所有算子。

  4. 检查:如果白名单里的算子不在xla注册器里将直接失败。

  5. 检查:如果属性里或者这个node的调用者带着不允许xla编译(_XlaCompile),则该op不能作为候选者。

  6. 检查:如果算子的设备不是一个已知的可转换成XLA设备的算子(目前仅支持DEVICE_CPU和DEVICE_GPU),则该op不能作为候选者。

  7. 检查:如果该算子的某项DeviceRegistration检查不符合,则该op不能作为候选者。

    1. 如果是source/sink节点。否

    2. 如果是arg/retval节点。否

    3. 如果在op的attr中找到:_scoped_allocator或者_forward_from否

    4. 如果FunctionCall节点:否

    5. 如果该节点没有xla kernel: 否

      1. 如果node类型包涵:SymbolicGradient 或者const型的string时:否

      2. HasForwardedRefInput:否

      3. FindKernelDef:否

    6. 如果是不能compilable的while node: 否

    7. 如果是不能compilable的if node: 否

    8. 如果是allow_stateful_rng_ops且IsStatefulRandomOp:否

    9. 剩余跟OpFilter一致。参考第九章。

  8. 检查:白名单中包涵该op,该op不能作为候选者。

  9. 检查:如果该op是stack或者array等对资源的op,则该op不能作为候选者。

  10. 检查:如果loop中的identity将不会被cluster。

  11. 如果能通过以上检查,则将算子加入到候选算子中。

七、白名单table:

          // Unary
          {"PW",
           {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
            "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
            "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
            "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", "Rsqrt",
            "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", "Square",
            "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Lgamma", "Digamma",
            // Binary
            "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
            "MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
            "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
            "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
            "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
            "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
            "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
            "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
            // Others
            "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
            "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
            "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
            "SelectV2", "Transpose", "ConjugateTranspose",
            "_UnaryOpsComposition",
            // The following 4 operations are converted to identity
            "PlaceholderWithDefault", "PreventGradient", "StopGradient",
            "Snapshot"}},
          // clang-format off
    {"RED",
     {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
          // clang-format on
          {"PWRED",
           {"ArgMax", "ArgMin", "DiagPart", "Softmax",
            "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
          {"REDUCEWINDOW",
           {"ArgMax", "ArgMin", "DiagPart", "Softmax",
            "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
          {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
          {"BN",
           {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
            "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
            "FusedBatchNormGradV3"}},
          {"SORT", {"TopKV2"}},  // XLA version much faster then TF version.
          {"MISC",
           // clang-format off
     {"BroadcastTo", "ExpandDims", "Fill", "NoOp",
      "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
      "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
      "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
      "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
      "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
      "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
      "Tile", "Transpose", "InvertPermutation", "Unpack"}}};
  // clang-format on

八、DeviceRegistration:

    // Describes how to compile operators assigned to a device.
		// 该类主要描述如何给一个compile算子分配device。
    struct DeviceRegistration {
      // The name of the an XLA compilation device to use to compile code.
      // device name
      string compilation_device_name;

      // When should we autocluster operators assigned to this device?
      // 分配device的时间。
      AutoclusteringPolicy autoclustering_policy;

      // If we should ignore the resource variable memory model when clustering
      // resource variable reads and writes placed on this device.
      // 对资源读写型op 是否放置在该设备上。
      bool cluster_resource_variable_ops_unsafely = false;

      // If we should auto-cluster Stack operations placed on this device.
      // 对 stack op是否放置在该设备上。
      bool cluster_stack_ops = false;

      // If we should auto-cluster TensorArray operations placed on this device.
      // 对 tensor array op 是否放置在该设备上。
      bool cluster_tensor_array_ops = false;

      // If we should auto-cluster stateful RNG operations placed on this device.
      // Stateful RNG semantics are not properly supported by XLA so it is not
      // necessarily correct to auto-cluster stateful RNG ops in general.
      // 对有状态的rng操作是否放置在该设备上。
      bool cluster_stateful_rng_ops = false;

      // If we should auto-cluster ControlTrigger operations placed on this
      // device.  ControlTrigger operations are not necessarily safe to cluster
      // since they affect deadness (a dead ControlTrigger produces a live
      // output).
      // 是否自动聚类 ControlTrigger op 在这个设备上。
      bool cluster_control_trigger = false;

      // If we should cluster Assert and CheckNumerics by eliding them (XLA does
      // not natively support Assert or CheckNumerics).
      // 是否对Assert和CheckNumerics聚类,xla不支持这两个算子。
      bool elide_assert_and_checknumerics = false;

      // If we should cluster operations returning DT_VARIANT.
      // VARIANT 类型是否聚类。
      bool cluster_variant_ops = false;

      // Whether ops known to be slow should be auto-clustered.
      // 已知有可能变慢的操作是否需要聚类。
      bool cluster_slow_ops = false;

      // Whether ops known to have numerical accuracy issues should be
      // auto-clustered.
      // 有可能会影响精度问题的操作是否需要聚类。
      bool cluster_inaccurate_ops = false;
    }

九、Device上的filter

  1. XLA_CPU_DEVICE:

>>  registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
    registration.autoclustering_policy =
        compile_on_demand
>>          ? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
>>          : XlaOpRegistry::AutoclusteringPolicy::kAlways;
    registration.cluster_resource_variable_ops_unsafely = true;
    registration.cluster_stack_ops = false;
    registration.cluster_tensor_array_ops = true;
    registration.cluster_stateful_rng_ops = true;
    registration.cluster_control_trigger = true;
    registration.elide_assert_and_checknumerics = true;
    registration.cluster_variant_ops = true;
    registration.cluster_slow_ops = true;
    registration.cluster_inaccurate_ops = true;

2. XLA_GPU_DEVICE:

>>  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
    registration.autoclustering_policy =
>>      XlaOpRegistry::AutoclusteringPolicy::kAlways;
    registration.cluster_resource_variable_ops_unsafely = true;
    registration.cluster_stack_ops = false;
    registration.cluster_tensor_array_ops = true;
    registration.cluster_stateful_rng_ops = true;
    registration.cluster_control_trigger = true;
    registration.elide_assert_and_checknumerics = true;
    registration.cluster_variant_ops = true;
    registration.cluster_slow_ops = true;
    registration.cluster_inaccurate_ops = true;
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值