刚看到这两个函数的时候不懂,到处找解说,最后发现还是官方文档说的清楚一点
在此记录一下,方便自己日后回忆
这是官方文档
torch.gather — PyTorch 2.1 documentation
torch.Tensor.scatter_ — PyTorch 2.1 documentation
1. gather
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
这是官方的定义,其中
input:表示输入的tensor
dim:表示索引的轴(这里的说法有点难理解,后面只能通过定义理解)
index:要收集的元素索引
官方的说法只看文字描述不好理解,我们通过(官方的)例子来说明:
t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1], [ 4, 3]])
gather操作实际是按照某种规律,将input的元素重新排列
dim和index两者一起构成寻找元素的索引,假设输出为output,规则如下:
out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1
可以这样理解,输出output和输入input的元素存在某种映射关系
只要掌握之间的转换关系,就可以算出output
以我们的具体例子来说,是这样的:
output[0][0] = input[0][0] # 坐标为(0,0),index[0][0] = 0,dim = 1, 因此替换为:(0,0)->(0,0)
output[0][1] = input[0][0] # 坐标为(0,1),index[0][1] = 0,dim = 1, 因此替换为:(0,1)->(0,0)
output[1][0] = input[1][1] # 坐标为(1,0),index[1][0] = 1,dim = 1, 因此替换为:(1,0)->(0,1)
output[1][1] = input[1][0] # 坐标为(1,1),index[1][1] = 0,dim = 1, 因此替换为:(1,1)->(1,0)
现在可以简单的计算以下程序的结果,这里令dim = 0:
t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 0, torch.tensor([[0, 0], [1, 0]]))
tensor([[1, 2], [3, 2]])
2. scatter_
这个函数与gather十分相似
Tensor.scatter_(dim, index, src, reduce=None)
变换规则为:
self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1
self为output,src相当于gather中的输入,与gather的变换规则几乎相同。
可自行参照官方代码尝试计算:
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]])
只要按照转换规则,就可以计算出函数的结果