pytorch 的 gather 和 scatter 这两个方法,理解起来是有一定难度的,正因如此,网上有很多博客介绍它们,但往往作者讲得并不清晰。导致读者即使当时看懂了,过段时间又会忘记。经过第N遍谷歌这两个方法的用法,本人终于想到一种类比,可以方便地理解并牢固地记住它们。
类比
假设某校高三有四个班级,各班花名册可以总结为下表:
现在他们要报高考志愿,假设总共有五个备选高校:清华、北大、复旦、上交、浙大,序号依次为 0-4。各班的报名状况为:
如果我想知道有哪些人报了清华大学,他们分别在哪个班,上面的表格是不够直观的。如果生成一个把大学作为行、把班级作为列的表格,则可以清晰地看到各个大学在各班的报名情况:
读到这里,恭喜您,已经完成了一次 pytorch 中的 scatter 操作!
scatter
首先摘录 pytorch 的官方文档:
scatter_(dim, index, src, reduce=None) → Tensor
For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
不要被它的例子里面复杂的下标索引吓到。我们先从简单的二维情况分析,把上面的例子第一行改成二维的:
self[index[i][j]][j] = src[i][j] # if dim == 0
然后,类比报志愿的问题,这里 src 是表一,index 是表二,而 self 则是表三。i 是学生在该班级内部的编号(所在行),j 是班级编号,注意 i 和 j 都从 0 开始。
那么,src[i][j] 表示的是 j 班的第 i 个同学。根据表一,src[0][0] 就是小红,src[1][2] 是小强。
index[i][j] 表示的是 j 班 i 号同学所报大学的编号。因此,显而易见的是 index 的形状应该和 src 完全一致(严格来讲,其实只需 index 各维度尺寸小于等于 src 各维度对应尺寸,index 在 dim 之外的维度的尺寸,小于等于 self 在dim 之外的尺寸),同时,index 表中最大的数,不应超出可选大学的最大下标,也就是 self 的末行下标。
gather 所做的操作就是,先查出 index[i][j]:找到第 j 班 i 同学所报的大学的编号,该编号就是需要修改的表三的行下标
;由于该同学所在的班级为 j 班,因此需要修改的表三的列下标为 j;确定了需要修改的表三的行、列下标,接下来需要确定的是修改的值,在我们的类比中,src[i][j] 是该同学的名字,因此把名字填入 self 中对应位置即可。
下面的代码演示了上面的类比过程:
import torch
# 生成表一,存的是各班学生姓名,用序号代替,序号按列(班级)依次递增
src = torch.arange(8).view(4, 2).transpose(0, 1).type(torch.float32)
print('src=')
print(src)
# tensor([[0, 2, 4, 6],
# [1, 3, 5, 7]])
# 生成表二,存的是各班各同学所报大学的下标
index = torch.LongTensor([[0, 1, 4, 3],[2, 3, 0, 1]])
print('index=')
print(index)
# tensor([[0, 1, 4, 3],
# [2, 3, 0, 1]])
#
# 初始化表三,即各校报名情况
tgt = torch.ones(5, 4)*-1
print('tgt=')
print(tgt)
# tensor([[-1., -1., -1., -1.],
# [-1., -1., -1., -1.],
# [-1., -1., -1., -1.],
# [-1., -1., -1., -1.],
# [-1., -1., -1., -1.]])
# dim=0 表示 index 存储的数字为 tgt 的行号,src 存的是学生名字
# scatter 将学生名字填入 tgt 对应的位置
tgt.scatter_(dim=0, index=index, src=src)
print('tgt=')
print(tgt)
# tensor([[ 0., -1., 5., -1.],
# [-1., 2., -1., 7.],
# [ 1., -1., -1., -1.],
# [-1., 3., -1., 6.],
# [-1., -1., 4., -1.]])
如果把代码中的 dim=0 改成 dim=1,scatter 的赋值规则变为:
self[i][index[i][j]] = src[i][j] # if dim == 1
回到上面的类比,index 存储的是 self 的列号,即大学的序号,i 是班级序号,j 是学生在该班的序号。最后得到的 self 为表三的转置。
至于 3D tensor,则可以扩展我们的类比,增加一个高中学校的编号:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
index[i][j][k]
:k 号高中,j 班,i 号学生所报高校编号
src[i][j][k]
:k 号高中,j 班,i号学生的名字
self[index[i][j][k]][j][k]
:index[i][j][k] 号高校所招的 k 号 高中,j 班学生的名字
小结
scatter_(dim, index, src, reduce=None) → Tensor
中,index 存储的是 self 中 dim 维的下标。scatter 所做的事情是,把 src 中 index 所覆盖的元素(想象把 index 表盖到 src 上),填充到 self 中。填充位置的下标如何确定呢?其中 dim 维下标由该元素对应的 index 确定,而其余维度下标保持和 index 中一致。
gather
摘录 pytorch 的文档:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
上面的类比还可以继续用。这里的 input 为表三,index 为表二,out为表一。即:scatter 是由表一和表二得到表三,而gather则是由表三和表二得到表一。gather 是 scatter 的逆运算。
gather 所做的事就是:知道了各班各同学报的学校下标 (index),也知道各学校在各班的招生情况了(input),反推哪个班有哪个同学(out)。
我们继续上一小节的代码:
input = torch.gather(tgt, 0, index)
print('input=')
print(input)
# tensor([[0., 2., 4., 6.],
# [1., 3., 5., 7.]])
果然得到了 scatter 中的 src。
scatter 的目标是修改现有矩阵,gather 是根据 index 从现有矩阵提取元素,构成一个新矩阵。新矩阵的形状与 index 一致,index 存储的元素决定了所提取的元素在 input 的 dim-维上的下标,dim-维之外的维度下标与 index 所存元素对应的下标一致。
总结
本文通过类比的方式讲了 scatter 和 gather 的用法,希望能够对读者有所帮助。对于里面可能存在的错漏的地方,望大神在评论区不吝赐教。