Pytorch的gather()和scatter()
1.gather()
gather是取的意思,意为把某一tensor矩阵按照一个索引序列index取出,组成一个新的矩阵。
gather(input,dim,index)
参数:
- input是要取值的矩阵
- dim指操作的维度,0为竖向操作即按行操作,1为横向操作即按列操作
- index为索引序列
下面这个例子是按行取出第一行的’0号元素’,'0行元素’组成新的第一行;
再取出第二行的‘1号元素’,‘0号元素’组成新的第二行
a = torch.Tensor([[1,2],[3,4]])
b = torch.gather(a, 1, torch.LongTensor([[0,0],[1,0]]))
print(a)
1 2
3 4
print(b)
1 1
4 3
2.scatter_()
这个是‘放’的意思,即把原tensor矩阵的元素按照新索引index的序列位置,放到新的矩阵中。
scatter_(dim,index,src)
参数:
- src 是要取出元素的矩阵
注意要放置的矩阵不在参数中,其直接调用这个函数。
下例就是按索引[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]位置,把随机矩阵a中元素放置到全0矩阵torch.zeros(3,5)中。
a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b)
其中dim=0
3.参考:
https://zhuanlan.zhihu.com/p/59346637