这是有关如何在TVM中进行循环计算的入门资料。递归计算是神经网络中的一种典型模式。
from __future__ import absolute_import,print_function
import tvm
import tvm.testing
from tvm import te
import numpy as np
TVM支持使用扫描运算符来描述符号循环。以下扫描运算将计算X列上的总和。
扫描在张量的最大维度上进行。 s_state是一个占位符,描述了扫描的过渡状态。 s_init描述了如何初始化前k个时间步。由于s_init的第一维为1,因此它描述了我们如何在第一时间步初始化状态。
s_update描述了如何在时间步t更新值。更新值可以通过状态占位符返回到先前时间步的值。请注意,虽然s_state在当前或以后的时间步中进行引用是无效的。
扫描将获取状态占位符,初始值和更新描述。还建议(尽管不是必需的)列出扫描单元的输入。扫描的结果是张量,给出s_state时域上更新后的结果。
m=te.var("m")
n=te.var("n")
X=te.placeholder((m,n),name="X")
s_state=te.placeholder((m,n))
s_init=te.compute((1,n),lambda _,i: X[0,i])
s_update=te.compute((m,n),lambda t,i: s_state[t-1,i]+X[t,i])
s_scan=te.scan(s_init,s_update,s_state,inputs=[X])
调度扫描单元
我们可以通过分别计划更新和初始化部分来计划扫描的主体。请注意,安排