1、Tensor支持基本索引和切片操作。
例:
(1)基本索引
输出为:
(2)切片
输出:
(3)带步长的切片(PyTorch现在不支持负步长)
输出:
2、Tensor支持ndarray中的高级索引(整数索引和布尔索引)操作。
例:
(1)整数索引
输出:
(2)布尔索引
输出:
3、torch.nonzero用于返回非零值的索引矩阵。
例:
(1)
输出:
(2)
输出:(不是唯一)
4、torch.where(condition,x,y)判断condition的条件是否满足。
当某个元素满足条件时,则返回对应矩阵x相同位置的元素,否则返回矩阵y的元素。
例:
输出:(不是唯一)
5、代码:
import torch
a = torch.arange(9).view(3,3)
#基本索引
print(a[2,2])
#切片
print(a[1:,:-1])
#带步长的切片(PyTorch现在不支持负步长)
print(a[::2])
#整数索引
rows = [0,1]
cols = [2,2]
print(a[rows,cols])
#布尔索引
index = a>4
print(index)
print(a[index])
a = torch.arange(9).view(3,3)
index = torch.nonzero(a >= 8)
print(index)
a = torch.randint(0,2,(3,3))
print(a)
index = torch.nonzero(a)
print(index)
x = torch.randn(3,2)
y = torch.ones(3,2)
print(x)
print(torch.where(x>0,x,y))