索引
import torch
# Integer indexing demo on a random 4-D tensor: every integer index
# consumes one leading dimension, and a[0][0] is equivalent to a[0, 0].
a = torch.rand(4, 3, 28, 28)
for _label, _value in (
    ("a.shape: ", a.shape),
    ("a[0].shape: ", a[0].shape),
    ("a[0][0].shape: ", a[0][0].shape),
    ("a[0, 0].shape: ", a[0, 0].shape),
    # Four integer indices address a single element (0-dim tensor).
    ("a[0, 1, 2, 3]: ", a[0, 1, 2, 3]),
):
    print(_label, _value)
取前N个或后N个
import torch
# Range slicing demo: ":n" keeps the first n entries of a dimension,
# "k:" keeps everything from position k on, and "-n:" keeps the last n.
a = torch.rand(4, 3, 28, 28)
_rows = [
    ("a.shape: ", a.shape),
    ("a[:2].shape: ", a[:2].shape),
    ("a[:2, :1, :, :].shape: ", a[:2, :1, :, :].shape),
    ("a[:2, 1:, :, :].shape: ", a[:2, 1:, :, :].shape),
    ("a[:2, -1:, :, :].shape: ", a[:2, -1:, :, :].shape),
]
for _row in _rows:
    print(*_row)
按步长选择
import torch
# Strided slicing demo using the start:stop:step form.
a = torch.rand(4, 3, 28, 28)

# Every second element among the first 20 of the last two dimensions.
_window = a[:, :, 0:20:2, 0:20:2]
print("a[:, :, 0:20:2, 0:20:2].shape: ", _window.shape)

# Omitting start/stop applies the step across the whole dimension.
_halved = a[:, :, ::2, ::2]
print("a[:, :, ::2, ::2].shape: ", _halved.shape)

# The same step syntax works on a 1-D tensor of the integers 0..9.
b = torch.arange(10)
print("b[::2]", b[::2])
自定义选取
import torch
# index_select demo: pick explicit positions along one dimension.
# The positions are passed as an index tensor.
a = torch.rand(4, 3, 28, 28)

_dim0_idx = torch.tensor([1, 2])
_dim1_idx = torch.tensor([1])

_from_dim0 = torch.index_select(a, 0, _dim0_idx)
print("a.index_select(0, torch.tensor([1,2])).shape: ", _from_dim0.shape)

_from_dim1 = torch.index_select(a, 1, _dim1_idx)
print("a.index_select(1, torch.tensor([1])).shape: ", _from_dim1.shape)
“...”符号
import torch
# Ellipsis demo: "..." stands for as many full ":" slices as are needed
# to cover the dimensions that are not indexed explicitly.
a = torch.rand(4, 3, 28, 28)
_cases = (
    ("a[...].shape: ", a[...]),
    ("a[0, ...].shape: ", a[0, ...]),
    ("a[:, 1, ...].shape: ", a[:, 1, ...]),
    ("a[..., :2].shape: ", a[..., :2]),
)
for _label, _view in _cases:
    print(_label, _view.shape)
mask选取
import torch
# masked_select demo: keep only the elements where a boolean mask is True.
a = torch.rand(3, 4)
print("a:\n", a)

# Element-wise "greater than or equal to 0.5" comparison yields the mask.
mask = torch.ge(a, 0.5)
print("mask:\n", mask)

# The selection is flattened to 1-D; its length depends on the random data.
_selected = torch.masked_select(a, mask)
print("torch.masked_select(a, mask): ", _selected)
print("torch.masked_select(a, mask).shape: ", _selected.shape)
take选取
import torch
# take demo: torch.take treats the input as if it were flattened to 1-D
# and gathers the elements at the given flat positions.
a = torch.rand(3, 4)
print("a:\n", a)

# Flat positions 0, 2 and 11 map to a[0, 0], a[0, 2] and a[2, 3] for a
# 3x4 tensor.
indices = torch.tensor([0, 2, 11])
taken = torch.take(a, indices)

# The label now matches the expression that is actually evaluated; it
# previously listed only [0, 2] and had an unbalanced parenthesis.
take_label = "torch.take(a, torch.tensor([0, 2, 11])):"
print(take_label, taken)