昨天晚上复现CPC代码过程中遇到了以下代码:
当时一看好家伙,张量索引里面套三维张量,并且通过索引居然还将原本的三维张量变为了四维张量,闻所未闻!随即主观认为是论文作者代码不太严谨(事实是冤枉了人家),通过后续验证发现,代码居然能够跑得通。。。。。。
最开始百思不得其解,之后便咨询了师兄并且查阅了一些博客,才知道不应过分关注维度而忽略索引张量中具体值的含义。
首先,降低维度把问题简单化:
将两个二维张量row_idx和col_idx作为二维全零张量mask的索引,并将其索引对应的部分改为1:
最后输出mask可以得到:
可以看出代码将原第3行的第2,3,4列,第二行的第1,3,2列以及第1行的第0,1,3列所对应的0修改为1,此时已经意思已经很明显了,即通过将row_idx映射到col的对应维度上,并通过其组成的索引值索引原张量。
注:row_idx和col_idx的维度务必相同,例如row_idx.size()=(256,1,1),col_idx.size()=(256,58,64),不然无法映射。
那在这个基础上深入一点:
将row_idx与col_idx的大小分别改为(3,1,1)与(3,2,2),并将其作为mask索引得到remask:
得到结果为:
可见mask通过上述索引,维度从(4,5)变为了(3,3,2),即通过索引张量row_idx和col_idx张量对应维度映射,将row_idx中的3值映射到[[1,2], [2,3], [3,4]],2值映射到[[1,1], [2,3], [2,3]],1值映射到[[0,1], [1,2], [3,4]]之上,总共3*3*2个值,因此remask的维度变为(3,3,2),实现了维度的扩张。