pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法

数组排序并返回前N值

对数组的第n个维度进项排序,并返回排序的前k个元素的values, indices

torch.topk(input, k, dim=n, largest=True, sorted=True, out=None) 
-> (Tensor, LongTensor)

例:取input的第1维

values, indices = torch.topk(input , 1, dim=1)

l a r g e s t { T r u e ,按照大到小排序 F a l s e ,按照小到大排序 largest\left\{\begin{array}{l}True,\mathrm{按照大到小排序}\\False,\mathrm{按照小到大排序}\end{array}\right. largest{True按照大到小排序False按照小到大排序

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

sorted:返回的结果按照顺序返回

out:可缺省,不要

按照索引取值:torch.gather(input,dim,index),或indicat_select

import torch
input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
index = torch.LongTensor([[3],[2],[4],[0]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, index)
————————————————
版权声明:https://blog.csdn.net/cpluss/article/details/90260550 https://www.zhihu.com/question/374472015

截取Tensor

初始方法

```c
res1 = []
for i in range(10):
    res1.append(i*3)
res = out[:, res1]


## [narrow](https://pytorch.org/docs/stable/generated/torch.narrow.html?highlight=narrow#torch.narrow)维度范围返回 。返回的张量和张量共享相同的底层存储。
Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。

```c
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)# 沿着x的第0维度,的第0位置开始,向下选取2个距离
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)# 沿着x的第1维度,的第1位置开始,向下选取2个距离
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9]])

mask方式选取torch.masked_select(input,mask)

>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)

tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],
        [-2.0099,  0.6249, -0.5382,  1.4458],
        [ 0.0684,  0.4118,  0.1011, -0.5684]])

>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)

tensor([[ True,  True, False, False],
        [False,  True, False,  True],
        [False, False, False, False]])

>>> print(torch.masked_select(x, mask))

tensor([1.2001, 1.2968, 0.6249, 1.4458])

————————————————
版权声明  https://pytorch.org/docs/stable/generated/torch.masked_select.html#torch.masked_select   https://cloud.tencent.com/developer/article/1755706

permute置换操作+res = input[:,0:N]

where()

>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)



  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值