torch.gather()

torch.gather(input, dim, index) → Tensor

Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d). out will have the same shape as index.

  • input (Tensor) :the source tensor
  • dim (int) :the axis along which to index
  • index (LongTensor) :the indices of elements to gather

Example:

# x为RNN的输出,output_dims为每个时间步的输出维度,需要按照index从中取出特定索引的预测值
# [batch_size, seq_len, output_dims]
x = torch.tensor([[[0.1, 0.7, 0.1, 0.1], [0.2, 0.6, 0.1, 0.1],[0.1, 0.5, 0.2, 0.2]],
                  [[0.2, 0.4, 0.2, 0.2], [0.3, 0.3, 0.2, 0.2],[0.1, 0.4, 0.2, 0.3]],
                  [[0.1, 0.6, 0.1, 0.2], [0.5, 0.3, 0.1, 0.1],[0.1, 0.1, 0.2, 0.6]]]) 
# [batch_size, seq_len]
index = torch.tensor([[0, 1, 2], [2, 1, 1], [0, 0, 0]])

index = index.unsqueeze(-1)   # 增加一个维度
pred = torch.gather(x, dim=2, index=index)

# 输出结果pred
tensor([[[0.1],[0.6],[0.2]],
        [[0.2],[0.3],[0.4]],
        [[0.1],[0.5],[0.1]]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值