Pytorch的gather()和scatter()

Pytorch的gather()和scatter()

1.gather()

gather是取的意思,意为把某一tensor矩阵按照一个索引序列index取出,组成一个新的矩阵。

gather(input,dim,index)
参数:

  • input是要取值的矩阵
  • dim指操作的维度,0为竖向操作即按行操作,1为横向操作即按列操作
  • index为索引序列

下面这个例子是按行取出第一行的’0号元素’,'0行元素’组成新的第一行;
再取出第二行的‘1号元素’,‘0号元素’组成新的第二行

a = torch.Tensor([[1,2],[3,4]])
b = torch.gather(a, 1, torch.LongTensor([[0,0],[1,0]]))
print(a)
1 2 
3 4
print(b)
1 1
4 3

2.scatter_()

这个是‘放’的意思,即把原tensor矩阵的元素按照新索引index的序列位置,放到新的矩阵中。

scatter_(dim,index,src)
参数:

  • src 是要取出元素的矩阵

注意要放置的矩阵不在参数中,其直接调用这个函数。

下例就是按索引[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]位置,把随机矩阵a中元素放置到全0矩阵torch.zeros(3,5)中。

a = torch.rand(2, 5)
print(a)
b = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), a)
print(b) 

其中dim=0
dim=0

3.参考:

https://zhuanlan.zhihu.com/p/59346637

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值