tensor按索引批量操作(torch.gather torch.scatter torch.scatter_reduce)

torch.gather是把tensor A的值基于dim顺序,根据index取出来;

torch.scatter是把tensor A的值基于dim顺序,根据index替换为src中的值;

torch.scatter_reduce是把tensor A的值基于dim顺序,根据index取出后,与src对应的值做reduce聚合。(注意:torch.scatter_reduce在torch>=1.13才有,否则建议使用torch_scatter包里的scatter函数)

举例torch.scatter:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

更新方式:
self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2

代码块中等号左边:

dim=?表示在self第?维度上取index的值,self其他维度取index所在的索引对应的值;

代码块中等号右边:

赋值的值,来自src。index中每个值所在的位置,对应src所在的位置。(因此要求index.shape <= src.shape)src取相应的值与index的值无关,只与index的位置(索引)有关。

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

'''
dim=0,在self第0维度对应的索引放index的值,其他维度的索引放index对应的索引;
src各维度放index对应的索引。

步骤:
index第一个元素:
dim=0,在self第0维度对应的索引放index的值,即0,得到:
self[0];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引),得到:
self[0][0];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][0],得到:
self[0][0]=src[0][0]=1;


index第二个元素:
dim=0,在self第0维度对应的索引放index的值,即1,得到:
self[1];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引,该索引为1),得到:
self[1][1];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][1],得到:
self[1][1]=src[0][1]=2;

...

index第四个元素:
dim=0,在self第0维度对应的索引放index的值,即0,得到:
self[0];
其他维度的索引放index对应的索引,(dim=0有索引了,只差dim=1的索引,该索引为4),得到:
self[0][4];
赋予的值是 src各维度放index对应的索引,index第四个元素所在位置为[0][4],得到:
self[0][4]=src[0][4]=4;


'''

第二个例子

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])
'''
dim=1,在self第1维度对应的索引放index的值,其他维度的索引放index对应的索引;
src各维度放index对应的索引。

步骤:
index第一个元素:
dim=1,在self第1维度对应的索引放index的值,即0,得到:
self[?][0];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引),得到:
self[0][0];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][0],得到:
self[0][0]=src[0][0]=1;

...

index第三个元素:
dim=1,在self第1维度对应的索引放index的值,即2,得到:
self[?][2];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引,该索引为0,因为在index第0行),得到:
self[0][1];
赋予的值是 src各维度放index对应的索引,index第一个元素所在位置为[0][1],得到:
self[0][1]=src[0][1]=2;

...

index第六个元素:
dim=1,在self第1维度对应的索引放index的值,即4,得到:
self[0][4];
其他维度的索引放index对应的索引,(dim=1有索引了,只差dim=0的索引,该索引为1, 因为在index第1行),得到:
self[1][4];
赋予的值是 src各维度放index对应的索引,index第四个元素所在位置为[0][4],得到:
self[1][4]=src[1][2]=8;

'''

题外话:脑子有点木,看半天绕不过来,感谢lhb大神的讲解。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值