Operator Optimization with Auto-Scheduling

Unlike template-based AutoTVM, which relies on manually written templates to define the search space, the auto-scheduler does not require any templates. The user only needs to write the compute declaration, without any schedule commands or templates. The auto-scheduler can automatically generate a large search space and find a good schedule within it.

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler

Define the matrix multiplication

@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]

Create the search task

target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

Output:

Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])

Set parameters for the auto-scheduler

  • num_measure_trials: the number of measurement trials that can be used during the search; in practice this is usually set to about 1000, while only 10 are used here for a quick demonstration
  • RecordToFile: log the measurement records to matmul.json. The records can be used to query the history best, resume the search, and do more analysis later

log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

Run the search

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

Output:

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Generate Sketches		#s: 3
Sample Initial Population	#s: 2002	fail_ct: 7	Time elapsed: 0.98
GA Iter: 0	Max score: 0.9994	Min score: 0.9368	#Pop: 128	#M+: 0	#M-: 0
GA Iter: 4	Max score: 0.9999	Min score: 0.9862	#Pop: 128	#M+: 1380	#M-: 74
EvolutionarySearch		#s: 128	Time elapsed: 4.04
----------------------------------------------------------------------
------------------------------  [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........
**********
==================================================
No: 1	GFLOPS: 153.35 / 153.35	results: MeasureResult(cost:[0.0140], error_no:0, all_cost:0.91, Tstamp:1653631216.48)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 64
parallel i.0@j.0@i.1@j.1@ (0,128)
  for k.0 (0,512)
    for i.2 (0,2)
      for j.2 (0,256)
        for k.1 (0,2)
          for i.3 (0,16)
            matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 2	GFLOPS: 28.68 / 153.35	results: MeasureResult(cost:[0.0749], error_no:0, all_cost:0.85, Tstamp:1653631217.13)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,8)
  matmul auto_unroll: 64
  for i.1 (0,4)
    for j.1 (0,32)
      for k.0 (0,1024)
        for i.2 (0,64)
          for i.3 (0,4)
            vectorize j.3 (0,4)
              matmul = ...
  for i.1 (0,1024)
    for j.1 (0,128)
      out = ...

==================================================
No: 3	GFLOPS: 167.54 / 167.54	results: MeasureResult(cost:[0.0128], error_no:0, all_cost:2.88, Tstamp:1653631217.65)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 512
parallel i.0@j.0@i.1@j.1@ (0,128)
  for k.0 (0,512)
    for i.2 (0,128)
      for j.2 (0,32)
        for k.1 (0,2)
          for i.3 (0,2)
            matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 4	GFLOPS: 228.06 / 228.06	results: MeasureResult(cost:[0.0094], error_no:0, all_cost:1.05, Tstamp:1653631218.19)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,2048)
  for k.0 (0,512)
    for i.2 (0,8)
      for j.2 (0,4)
        for k.1 (0,2)
          for j.3 (0,16)
            matmul = ...
  for i.2 (0,8)
    for j.2 (0,64)
      out = ...

==================================================
No: 5	GFLOPS: 239.78 / 239.78	results: MeasureResult(cost:[0.0090], error_no:0, all_cost:1.03, Tstamp:1653631218.83)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,1024)
  matmul auto_unroll: 512
  for i.1 (0,2)
    for j.1 (0,32)
      for k.0 (0,512)
        for k.1 (0,2)
          for i.3 (0,16)
            matmul = ...
  for i.1 (0,32)
    for j.1 (0,32)
      out = ...

==================================================
No: 6	GFLOPS: 60.33 / 239.78	results: MeasureResult(cost:[0.0356], error_no:0, all_cost:1.80, Tstamp:1653631219.34)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 512
parallel i.0@j.0@i.1@j.1@ (0,2048)
  for k.0 (0,16)
    for i.2 (0,512)
      for k.1 (0,64)
        matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 7	GFLOPS: 51.90 / 239.78	results: MeasureResult(cost:[0.0414], error_no:0, all_cost:0.68, Tstamp:1653631219.85)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,2048)
  for k.0 (0,64)
    for i.2 (0,4)
      for j.2 (0,16)
        for k.1 (0,16)
          for i.3 (0,8)
            matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 8	GFLOPS: 17.02 / 239.78	results: MeasureResult(cost:[0.1262], error_no:0, all_cost:1.04, Tstamp:1653631220.72)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,1024)
  for k.0 (0,64)
    for i.2 (0,8)
      for j.2 (0,8)
        for k.1 (0,16)
          for i.3 (0,16)
            matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 9	GFLOPS: 350.08 / 350.08	results: MeasureResult(cost:[0.0061], error_no:0, all_cost:0.60, Tstamp:1653631221.20)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 64
parallel i.0@j.0@ (0,8192)
  for i.1 (0,4)
    for j.1 (0,4)
      for k.0 (0,16)
        for k.1 (0,64)
          for i.3 (0,2)
            vectorize j.3 (0,4)
              matmul = ...
parallel i@j@ (0,1048576)
  out = ...

==================================================
No: 10	GFLOPS: 485.02 / 485.02	results: MeasureResult(cost:[0.0044], error_no:0, all_cost:0.74, Tstamp:1653631221.79)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,4096)
  for k.0 (0,1024)
    for i.2 (0,2)
      for j.2 (0,8)
        vectorize j.3 (0,16)
          matmul = ...
parallel i@j@ (0,1048576)
  out = ...

Time elapsed for measurement: 9.27 s
----------------------------------------------------------------------
------------------------------  [ Done ]
----------------------------------------------------------------------

Inspect the optimized schedule

After auto-tuning finishes, we can lower the schedule and inspect the generated code. The optimizations may include multi-level tiling, layout transformation, parallelization, vectorization, loop unrolling, and operator fusion:

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

Output:

Lowered TIR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
             B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
             C: Buffer(C_2: Pointer(float32), float32, [1048576], []),
             out: Buffer(out_2: Pointer(float32), float32, [1048576], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], []), out_1: out_3: Buffer(out_2, float32, [1024, 1024], [])} {
  allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global;
  allocate(matmul: Pointer(global float32), float32, [1048576]), storage_scope = global {
    for (ax0.ax1.fused.ax2.fused: int32, 0, 8192) "parallel" {
      for (ax3: int32, 0, 8) {
        for (ax5: int32, 0, 16) {
          let cse_var_1: int32 = (ax3*16)
          auto_scheduler_layout_transform_1: Buffer(auto_scheduler_layout_transform, float32, [1048576], [])[(((ax0.ax1.fused.ax2.fused*128) + cse_var_1) + ax5)] = B[((((floormod(ax0.ax1.fused.ax2.fused, 1024)*1024) + (floordiv(ax0.ax1.fused.ax2.fused, 1024)*128)) + cse_var_1) + ax5)]
        }
      }
    }
    for (i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused: int32, 0, 4096) "parallel" {
      for (i.outer.inner.init: int32, 0, 2) {
        for (j.outer.inner.init: int32, 0, 8) {
          matmul_1: Buffer(matmul, float32, [1048576], [])[ramp(((((floordiv(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*2048) + (i.outer.inner.init*1024)) + (floormod(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*128)) + (j.outer.inner.init*16)), 1, 16)] = broadcast(0f32, 16)
        }
      }
      for (k.outer: int32, 0, 1024) {
        for (i.outer.inner: int32, 0, 2) {
          for (j.outer.inner: int32, 0, 8) {
            let cse_var_5: int32 = floormod(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)
            let cse_var_4: int32 = (j.outer.inner*16)
            let cse_var_3: int32 = ((floordiv(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*2048) + (i.outer.inner*1024))
            let cse_var_2: int32 = ((cse_var_3 + (cse_var_5*128)) + cse_var_4)
            matmul_1[ramp(cse_var_2, 1, 16)] = (matmul_1[ramp(cse_var_2, 1, 16)] + (broadcast(A[(cse_var_3 + k.outer)], 16)*auto_scheduler_layout_transform_1[ramp((((cse_var_5*131072) + (k.outer*128)) + cse_var_4), 1, 16)]))
          }
        }
      }
    }
    for (i.j.fused: int32, 0, 1048576) "parallel" {
      out[i.j.fused] = (matmul_1[i.j.fused] + C[i.j.fused])
    }
  }
}

Check correctness and evaluate performance

func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np

dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)

Output:

Execution time of this operator: 4.906 ms
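
For reference, the same computation can be timed with NumPy to get a rough baseline. The sketch below reuses the arrays defined above; the number of repeats is an arbitrary choice for illustration:

import timeit

# Time the NumPy reference implementation of A.dot(B) + C as a rough baseline.
numpy_time = timeit.timeit(lambda: a_np.dot(b_np) + c_np, number=10) / 10
print("NumPy baseline: %.3f ms" % (numpy_time * 1000))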

Using the record file

The example below loads the best schedule from the record file and prints it as the equivalent Python schedule API. This can be used to debug and learn the behavior of the auto-tuning process:

print("Equivalent python schedule:")
print(task.print_best(log_file))

Output:

Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=1)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=2)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=1)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=16)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=8)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=8)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=1)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused = s[matmul].fuse(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i)
s[matmul].parallel(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused)
out_i_j_fused = s[out].fuse(out_i, out_j)
s[out].parallel(out_i_j_fused)
s[matmul].pragma(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused, "auto_unroll_max_step", 0)
s[matmul].pragma(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
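
Besides printing the equivalent schedule, the record file can also be queried programmatically. Below is a minimal sketch, assuming your TVM version exports auto_scheduler.load_best_record, which returns the best MeasureInput/MeasureResult pair found in the log:

# Inspect the best record in the log programmatically
# (assumes auto_scheduler.load_best_record is available in this TVM version).
inp, res = auto_scheduler.load_best_record(log_file)
costs_ms = [c.value * 1000 for c in res.costs]  # measured costs in milliseconds
print("Best workload key:", inp.task.workload_key)
print("Best measured costs (ms):", costs_ms)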

A more complicated example is to resume the search. In this case, we need to create the search policy and cost model ourselves and restore their states from the log file. In the example below, we restore the states and run 5 more trials.

def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    task.tune(tune_option, search_policy=search_policy)


resume_search(task, log_file)

Output:

# Note: if the search fails with "cannot import name 'EarlyStopException' from 'xgboost.core'",
# downgrade xgboost:
# pip3 uninstall xgboost
# pip3 install xgboost==1.5.0 -i https://pypi.douban.com/simple/

Resume search:
/root/anaconda3/envs/py37/lib/python3.7/site-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
  warnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)
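
The resumed trials are appended to the same matmul.json log, so after the search finishes the best schedule found so far can be re-applied and compiled exactly as before; a short sketch reusing the objects defined above:

# Re-apply the best schedule from the (now extended) log and rebuild the operator.
sch, args = task.apply_best(log_file)
func = tvm.build(sch, args, target)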