VTA中的数据精度变化
数据精度在何处定义:
在tvm-main\3rdparty\vta-hw\hardware\xilinx路径下的Makefile是VTA的构建过程。
VTA中的数据精度是由vta_config.json文件指定的,如下图。
Make过程中会调用vta_config.py,该脚本文件会调用pkg_config()函数,该函数会调用pkg_config.py中的PkgConfig(),PkgConfig()会根据vta_config.json文件对如下图所示变量进行赋值,
这些变量稍作处理(加“VTA_”,如下图)会在vta.h和vta.cc中使用。
数据精度的变化和方式:
涉及到数据精度变化的操作有compute中的gemm、ALU,所有的输入、权重和输出采用的均是8bit数据,所有精度变化均采用的是直接截取低8bit的形式。
gemm中数据的位宽如下图所示(黄色加深):
如上代码中白色标注处,
1:其中每两个8bit的输入i_elem和权重w_elem相乘的结果存入17bit的mul_T prod_dsp中,
2:随后17bit的数据被累加(累加VTA_BLOCK_IN=16次,对输入的16维向量和权重的列向量求向量乘法)到21bit的sum_T tmp中,
3:随后该数据被外层循环(VTA_BLOCK_OUT=16次)累加到32bit的acc_T accum中,
4:然后该数据被写回到累加缓存a_tensor中,
5:最后使用了hls中的range,获取了低8bit数据,将其写回到了输出缓存o_tensor中
如下是查询到的range操作的含义,该操作与C++中的range有所不同,为hls扩展而来:
alu中数据的位宽如下图所示(黄色加深):
在alu操作中有三种类型分别为:Compute Min/Max、Compute Sum和Compute Shift Right。
Compute Min/Max中计算结果取低8bit写回到输出缓存o_tensor中。
Compute Sum中计算结果取低8bit写回到输出缓存o_tensor中。
Compute Shift Right中会对src_0数据进行shft_by个右移位,shft_by为src_1的低5bit,也就是最大移位值为,5b’11111=5d’31,随后数据被取低8bit给到输出缓存。
总结:需要注意的是在给到输出缓存数据时候,全部对高位宽数据取低8bit,gemm中数据精度变化顺序为8bit->17bit->21bit->32bit,并非直接8bit->32bit,其他指令均未涉及到精度变化。
由于格式问题,以下是本文章的pdf版本:VTA中的数据精度变化