torch.gather 和 torch.scatter_

刚看到这两个函数的时候不懂,到处找解说,最后发现还是官方文档说的清楚一点

在此记录一下,方便自己日后回忆

这是官方文档

torch.gather — PyTorch 2.1 documentation

torch.Tensor.scatter_ — PyTorch 2.1 documentation

1. gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

这是官方的定义,其中

input:表示输入的tensor

dim:表示索引的轴(这里的说法有点难理解,后面只能通过定义理解)

index:要收集的元素索引

官方的说法只看文字描述不好理解,我们通过(官方的)例子来说明:

t = torch.tensor([[1, 2], [3, 4]]) 
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) 
tensor([[ 1, 1], [ 4, 3]])

gather操作实际是按照某种规律,将input的元素重新排列

dim和index两者一起构成寻找元素的索引,假设输出为output,规则如下:

out[i][j] = input[index[i][j]][j] # if dim == 0

out[i][j] = input[i][index[i][j]] # if dim == 1

可以这样理解,输出output和输入input的元素存在某种映射关系

只要掌握之间的转换关系,就可以算出output

以我们的具体例子来说,是这样的:

output[0][0] = input[0][0] # 坐标为(0,0),index[0][0] = 0,dim = 1, 因此替换为:(0,0)->(0,0)

output[0][1] = input[0][0] # 坐标为(0,1),index[0][1] = 0,dim = 1, 因此替换为:(0,1)->(0,0)

output[1][0] = input[1][1] # 坐标为(1,0),index[1][0] = 1,dim = 1, 因此替换为:(1,0)->(0,1)

output[1][1] = input[1][0] # 坐标为(1,1),index[1][1] = 0,dim = 1, 因此替换为:(1,1)->(1,0)

现在可以简单的计算以下程序的结果,这里令dim = 0:

t = torch.tensor([[1, 2], [3, 4]]) 
>>> torch.gather(t, 0, torch.tensor([[0, 0], [1, 0]])) 
tensor([[1, 2], [3, 2]])

2. scatter_

这个函数与gather十分相似

Tensor.scatter_(dim, index, src, reduce=None)

变换规则为:

self[index[i][j]][j] = src[i][j] # if dim == 0

self[i][index[i][j]] = src[i][j] # if dim == 1

self为output,src相当于gather中的输入,与gather的变换规则几乎相同。

可自行参照官方代码尝试计算:

>>> src = torch.arange(1, 11).reshape((2, 5))

>>> src

tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]])

>>> index = torch.tensor([[0, 1, 2, 0]])

>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

tensor([[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]])

只要按照转换规则,就可以计算出函数的结果

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值