与依赖手动模板定义搜索空间的基于模板的 AutoTVM 不同,auto scheduler 不需要任何模板。用户只需编写计算声明,无需任何调度命令或模板。自动调度程序可以自动生成一个较大的搜索空间,并在该空间中找到一个好的调度。
import os
import numpy as np
import tvm
from tvm import te, auto_scheduler
定义矩阵相乘
@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
    """Declare the TE workload out = A @ B + C.

    Parameters
    ----------
    N, L, M : int
        Matrix dimensions: A is (N, L), B is (L, M), C and out are (N, M).
    dtype : str
        Element type of all tensors, e.g. "float32".

    Returns
    -------
    list
        The tensor list [A, B, C, out] consumed by auto_scheduler.
    """
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    # Reduction axis over the shared dimension L.
    rk = te.reduce_axis((0, L), name="k")
    mm = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, rk] * B[rk, j], axis=rk),
        name="matmul",
        # enable automatic layout transform for tensor B
        attrs={"layout_free_placeholders": [B]},
    )
    out = te.compute((N, M), lambda i, j: mm[i, j] + C[i, j], name="out")
    return [A, B, C, out]
创建搜索任务
# Create the search task: tune the 1024x1024x1024 float32 matmul+add for a
# generic LLVM CPU target.
target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)
输出:
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])
设置 Auto-Scheduler 的参数
num_measure_trials
:可以在搜索过程中使用的测量试验数,实际工程中通常设置为 1000
RecordToFile
:将测试记录到 matmul.json。测试记录可用于查询历史最佳记录、恢复搜索以及稍后进行更多分析
# File that accumulates every measurement record produced during tuning.
log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # small for the demo; use ~1000 in real tuning
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,  # print per-trial measurement results
)
运行搜索
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule: replay the best record in the log file to get a
# TE schedule plus its tensor arguments.
sch, args = task.apply_best(log_file)
输出:
----------------------------------------------------------------------
------------------------------ [ Search ]
----------------------------------------------------------------------
Generate Sketches #s: 3
Sample Initial Population #s: 2002 fail_ct: 7 Time elapsed: 0.98
GA Iter: 0 Max score: 0.9994 Min score: 0.9368 #Pop: 128 #M+: 0 #M-: 0
GA Iter: 4 Max score: 0.9999 Min score: 0.9862 #Pop: 128 #M+: 1380 #M-: 74
EvolutionarySearch #s: 128 Time elapsed: 4.04
----------------------------------------------------------------------
------------------------------ [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........
**********
==================================================
No: 1 GFLOPS: 153.35 / 153.35 results: MeasureResult(cost:[0.0140], error_no:0, all_cost:0.91, Tstamp:1653631216.48)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 64
parallel i.0@j.0@i.1@j.1@ (0,128)
for k.0 (0,512)
for i.2 (0,2)
for j.2 (0,256)
for k.1 (0,2)
for i.3 (0,16)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 2 GFLOPS: 28.68 / 153.35 results: MeasureResult(cost:[0.0749], error_no:0, all_cost:0.85, Tstamp:1653631217.13)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,8)
matmul auto_unroll: 64
for i.1 (0,4)
for j.1 (0,32)
for k.0 (0,1024)
for i.2 (0,64)
for i.3 (0,4)
vectorize j.3 (0,4)
matmul = ...
for i.1 (0,1024)
for j.1 (0,128)
out = ...
==================================================
No: 3 GFLOPS: 167.54 / 167.54 results: MeasureResult(cost:[0.0128], error_no:0, all_cost:2.88, Tstamp:1653631217.65)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 512
parallel i.0@j.0@i.1@j.1@ (0,128)
for k.0 (0,512)
for i.2 (0,128)
for j.2 (0,32)
for k.1 (0,2)
for i.3 (0,2)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 4 GFLOPS: 228.06 / 228.06 results: MeasureResult(cost:[0.0094], error_no:0, all_cost:1.05, Tstamp:1653631218.19)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,2048)
for k.0 (0,512)
for i.2 (0,8)
for j.2 (0,4)
for k.1 (0,2)
for j.3 (0,16)
matmul = ...
for i.2 (0,8)
for j.2 (0,64)
out = ...
==================================================
No: 5 GFLOPS: 239.78 / 239.78 results: MeasureResult(cost:[0.0090], error_no:0, all_cost:1.03, Tstamp:1653631218.83)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,1024)
matmul auto_unroll: 512
for i.1 (0,2)
for j.1 (0,32)
for k.0 (0,512)
for k.1 (0,2)
for i.3 (0,16)
matmul = ...
for i.1 (0,32)
for j.1 (0,32)
out = ...
==================================================
No: 6 GFLOPS: 60.33 / 239.78 results: MeasureResult(cost:[0.0356], error_no:0, all_cost:1.80, Tstamp:1653631219.34)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 512
parallel i.0@j.0@i.1@j.1@ (0,2048)
for k.0 (0,16)
for i.2 (0,512)
for k.1 (0,64)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 7 GFLOPS: 51.90 / 239.78 results: MeasureResult(cost:[0.0414], error_no:0, all_cost:0.68, Tstamp:1653631219.85)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,2048)
for k.0 (0,64)
for i.2 (0,4)
for j.2 (0,16)
for k.1 (0,16)
for i.3 (0,8)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 8 GFLOPS: 17.02 / 239.78 results: MeasureResult(cost:[0.1262], error_no:0, all_cost:1.04, Tstamp:1653631220.72)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,1024)
for k.0 (0,64)
for i.2 (0,8)
for j.2 (0,8)
for k.1 (0,16)
for i.3 (0,16)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 9 GFLOPS: 350.08 / 350.08 results: MeasureResult(cost:[0.0061], error_no:0, all_cost:0.60, Tstamp:1653631221.20)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 64
parallel i.0@j.0@ (0,8192)
for i.1 (0,4)
for j.1 (0,4)
for k.0 (0,16)
for k.1 (0,64)
for i.3 (0,2)
vectorize j.3 (0,4)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
==================================================
No: 10 GFLOPS: 485.02 / 485.02 results: MeasureResult(cost:[0.0044], error_no:0, all_cost:0.74, Tstamp:1653631221.79)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,4096)
for k.0 (0,1024)
for i.2 (0,2)
for j.2 (0,8)
vectorize j.3 (0,16)
matmul = ...
parallel i@j@ (0,1048576)
out = ...
Time elapsed for measurement: 9.27 s
----------------------------------------------------------------------
------------------------------ [ Done ]
----------------------------------------------------------------------
查看优化后的调度
完成自动调优之后,可以查看代码,优化可能包含:多级分块,布局重排,并行化,向量化,循环展开,算子融合等:
# Lower the tuned schedule to TIR to inspect the applied optimizations.
print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
输出:
Lowered TIR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
C: Buffer(C_2: Pointer(float32), float32, [1048576], []),
out: Buffer(out_2: Pointer(float32), float32, [1048576], [])}
buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out}
preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], []), out_1: out_3: Buffer(out_2, float32, [1024, 1024], [])} {
allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global;
allocate(matmul: Pointer(global float32), float32, [1048576]), storage_scope = global {
for (ax0.ax1.fused.ax2.fused: int32, 0, 8192) "parallel" {
for (ax3: int32, 0, 8) {
for (ax5: int32, 0, 16) {
let cse_var_1: int32 = (ax3*16)
auto_scheduler_layout_transform_1: Buffer(auto_scheduler_layout_transform, float32, [1048576], [])[(((ax0.ax1.fused.ax2.fused*128) + cse_var_1) + ax5)] = B[((((floormod(ax0.ax1.fused.ax2.fused, 1024)*1024) + (floordiv(ax0.ax1.fused.ax2.fused, 1024)*128)) + cse_var_1) + ax5)]
}
}
}
for (i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused: int32, 0, 4096) "parallel" {
for (i.outer.inner.init: int32, 0, 2) {
for (j.outer.inner.init: int32, 0, 8) {
matmul_1: Buffer(matmul, float32, [1048576], [])[ramp(((((floordiv(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*2048) + (i.outer.inner.init*1024)) + (floormod(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*128)) + (j.outer.inner.init*16)), 1, 16)] = broadcast(0f32, 16)
}
}
for (k.outer: int32, 0, 1024) {
for (i.outer.inner: int32, 0, 2) {
for (j.outer.inner: int32, 0, 8) {
let cse_var_5: int32 = floormod(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)
let cse_var_4: int32 = (j.outer.inner*16)
let cse_var_3: int32 = ((floordiv(i.outer.outer.outer.j.outer.outer.outer.fused.i.outer.outer.inner.fused.j.outer.outer.inner.fused, 8)*2048) + (i.outer.inner*1024))
let cse_var_2: int32 = ((cse_var_3 + (cse_var_5*128)) + cse_var_4)
matmul_1[ramp(cse_var_2, 1, 16)] = (matmul_1[ramp(cse_var_2, 1, 16)] + (broadcast(A[(cse_var_3 + k.outer)], 16)*auto_scheduler_layout_transform_1[ramp((((cse_var_5*131072) + (k.outer*128)) + cse_var_4), 1, 16)]))
}
}
}
}
for (i.j.fused: int32, 0, 1048576) "parallel" {
out[i.j.fused] = (matmul_1[i.j.fused] + C[i.j.fused])
}
}
}
检查正确性和评估性能
# Compile the tuned schedule into an executable function.
func = tvm.build(sch, args, target)

# Random inputs plus the NumPy reference result for the correctness check.
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np

# Copy inputs to the CPU device and allocate the output buffer.
dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)

# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
输出:
Execution time of this operator: 4.906 ms
使用记录文件
下面的例子是从记录文件中获取最佳的调度,然后打印等价的 python 调度 API,从这里可以学到如何调试和自动调优的行为:
# Print the Python schedule API calls equivalent to the best record in the log.
print("Equivalent python schedule:")
print(task.print_best(log_file))
输出:
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=1)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=2)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=1)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=16)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=8)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=8)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=1)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused = s[matmul].fuse(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i)
s[matmul].parallel(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused)
out_i_j_fused = s[out].fuse(out_i, out_j)
s[out].parallel(out_i_j_fused)
s[matmul].pragma(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused, "auto_unroll_max_step", 0)
s[matmul].pragma(matmul_i_o_o_o_j_o_o_o_fused_i_o_o_i_fused_j_o_o_i_fused, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
一个更复杂的例子是恢复搜索。在这种情况下,我们需要自己创建搜索策略和代价模型,并使用日志文件恢复搜索策略和代价模型的状态。在下面的示例中,我们恢复状态并再进行 5 次试验。
def resume_search(task, log_file):
    """Resume an auto-scheduler search from an existing record file.

    Rebuilds the cost model and the sketch search policy from the records in
    *log_file*, then runs 5 additional measurement trials, appending the new
    records to the same file.
    """
    print("Resume search:")

    # Warm-start the XGBoost cost model with previously measured records.
    model = auto_scheduler.XGBModel()
    model.update_from_file(log_file)

    # Preload measured states so already-tried candidates are not re-measured.
    policy = auto_scheduler.SketchPolicy(
        task,
        model,
        init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)],
    )

    options = auto_scheduler.TuningOptions(
        num_measure_trials=5,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    task.tune(options, search_policy=policy)
# Resume tuning from the existing log file for 5 additional trials.
resume_search(task, log_file)
输出:
# cannot import name 'EarlyStopException' from 'xgboost.core'
# pip3 uninstall xgboost
# pip3 install xgboost==1.5.0 -i https://pypi.douban.com/simple/
Resume search:
/root/anaconda3/envs/py37/lib/python3.7/site-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated. See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)