gather torch_Pytorch 中的 torch.gather 函数

官方文档

torch.gather(input, dim, index, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # dim=0

out[i][j][k] = input[i][index[i][j][k]][k] # dim=1

out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

Parameters:

input (Tensor) – The source tensor

dim (int) – The axis along which to index

index (LongTensor) – The indices of elements to gather

out (Tensor, optional) – Destination tensor

Example:

>>> t = torch.Tensor([[1,2],[3,4]])

>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))

1 1

4 3

[torch.FloatTensor of size 2x2]

torch.gather 函数用于从参数 t 选择性输出特定 index 的矩阵,输出矩阵的大小跟 index 的大小是一样的,torch.gather 的 dim 参数用来选择 index 作用的 axis。

构建 2×18×2×2 的矩阵 a,

# a[:,i,:,:] = i

a = torch.arange(18).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(2,18,2,2)

a[:,2,:,:]

Out[24]:

tensor([[[ 2., 2.],

[ 2., 2.]],

[[ 2., 2.],

[ 2., 2.]]])

a[:,16,:,:]

Out[25]:

tensor([[[ 16., 16.],

[ 16., 16.]],

[[ 16., 16.],

[ 16., 16.]]])

现在要通过 torch.gather 函数把 a 变成 offset[:,:9,:,:] = [0,2,...16],offset[:,9:,:,:] = [1,3,..,17]

N = 9

offsets_index = Variable(torch.cat([torch.arange(0, 2*N, 2), torch.arange(1, 2*N+1, 2)]), requires_grad=False).long()

offsets_index

Out[29]:

tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5,

7, 9, 11, 13, 15, 17])

offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*a.size())

offset = torch.gather(a, dim=1, index=offsets_index)

offset[:,0,:,:]

Out[34]:

tensor([[[ 0., 0.],

[ 0., 0.]],

[[ 0., 0.],

[ 0., 0.]]])

offset[:,1,:,:]

Out[35]:

tensor([[[ 2., 2.],

[ 2., 2.]],

[[ 2., 2.],

[ 2., 2.]]])

offset[:,8,:,:]

Out[39]:

tensor([[[ 16., 16.],

[ 16., 16.]],

[[ 16., 16.],

[ 16., 16.]]])

offset[:,9,:,:]

Out[40]:

tensor([[[ 1., 1.],

[ 1., 1.]],

[[ 1., 1.],

[ 1., 1.]]])

offset[:,17,:,:]

Out[41]:

tensor([[[ 17., 17.],

[ 17., 17.]],

[[ 17., 17.],

[ 17., 17.]]])

代码帖的有点多,主要是为了验证效果。

offset 的输出规则如下:

offset[i][j][k][s] = input[i][offsets_index[i][j][k][s]][k][s] # dim=1

因为 dim = 1,offsets_index 影响 axis = 1 的维度,offset[i][j][k][s] 由 input 根据 offsets_index 在 axis=1 维度用 offsets_index[i][j][k][s] 作为索引,其他的位置不变,同理其他维度改变就用 index[i][j][k][s] 作为对应 axis 的索引 。

最终输出 offset 的时候,offset[:][1][:][:] 的数据只是选择了 input 在 axis=1 上 input[:][2][:][:] 的所有数据,在第 axis=0,2,3 维度 input 的索引和 offset 是对应的,所以 offset 在相应位置上的数据和 input 一样。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值