本次将介绍一下 Tensor 张量常用的索引与切片的方法:
1. index 索引
index 索引值表示相应维度值的对应索引
a = torch.rand(4, 3, 28, 28)
print(a[0].shape) # 返回维度一的第 0 索引 tensor
print(a[0, 0].shape) # 返回维度一 0 索引位置,维度二 0 索引位置的 tensor
print(a[0, 0, 0].shape) # 返回维度一 0 索引,维度二 0 索引,维度三 0索引的 tensor
print(a[0, 0, 2, 4].shape) # 返回维度一 0 索引,维度二 0 索引,维度三 2索引,维度四 4索引位置的 tensor (dim = 0)
print(a[0, 0, 2, 4])
# 输出结果
torch.Size([3, 28, 28])
torch.Size([28, 28])
torch.Size([28])
torch.Size([])
tensor(0.4504)
2. select first/last N
返回前 N 个或后 N 个的 tensor
【:】表示该维度所有值;
【:2】表示从索引 0 开始到索引 2 的值,包首不包尾
【1:】表示索引 1 开始到最后
【-2:】表示倒数第二个值到最后
【…】表示一个或几个维度不变
a = torch.rand(4, 3, 28, 28)
print(a[:2].shape) # 返回维度一索引 0 ~ 2 的 tensor,相当于 a[:2, :, :, :].shape, : 表示都选择
print(a[:2, :1, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 0 ~ 1 的 tensor
print(a[:2, :1, :3, :4].shape) # 返回维度一索引 0 ~ 2,维度二索引 0 ~ 1,维度三索引 0 ~ 3,维度四索引 0 ~ 4 的 tensor
print(a[:2, 1:, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 1 ~ 3 的 tensor
print(a[:2, -2:, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 1 ~ 3 的 tensor
# ---------【...】的应用 --------------
print(a[...].shape) # 表示返回一样的 a
print(a[0, ...].shape) # 表示返回维度一,索引 0 位置的 tensor
print(a[:, 1, ...].shape) # 表示返回维度二,索引 1 位置的 tensor
print(a[:, :, 2, ...].shape) # 表示返回维度三,索引 2 位置的 tensor
print(a[..., 10].shape) # 表示返回维度四,索引 10 位置的 tensor
print(a[..., :2].shape) # 表示返回维度四,索引 0 ~2 数量的 tensor
# 输出结果
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 3, 4])
torch.Size([2, 2, 28, 28])
torch.Size([2, 2, 28, 28])
# ---------【...】的应用的输出结果 --------------
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 3, 28])
torch.Size([4, 3, 28])
torch.Size([4, 3, 28, 2])
3. select by steps
按一定的间隔 steps 返回 tensor
【0:28:2】表示从索引 0 开始到 28,间隔 2 取数,所以为 14
有二个冒号,便是按一定间隔取
a = torch.rand(4, 3, 28, 28)
print(a[:, :, 0:28:2, 0:28:4].shape)
#输出结果
torch.Size([4, 3, 14, 7])
4. index_select(intputTensor, dim, indexTensor)
根据输入的 inputTensor ,按指定的维度索引 dim,返回与 indexTensor 一样的 size,其它维度不变的新 tensor
a = torch.rand(4, 3, 28, 28)
b = a.index_select(2, torch.arange(8)) # 也可以 inputTensor 直接调用
c = torch.index_select(a, 2, torch.arange(8)) # 建议用这种形式,返回 a 第 3 个维度与 torch.arange(8)一样 size ,其它维度不变的新 tensor
print(b.shape)
print(c.shape)
# 输出结果
torch.Size([4, 3, 8, 28])
torch.Size([4, 3, 8, 28])
5. masked_select(intputTensor, maskTensor)
返回一个满足 maskTensor 条件的一维 tensor
a = torch.rand(3, 4)
print(a)
x = a.ge(0.5) # 大于 0.5 的 bool 张量
print(x)
print(a.masked_select(x)) # 返回值大于 0.5 的一维张量
print(torch.masked_select(a, x)) # 和上面一样,但建议用这种形式
# 输出结果
tensor([[0.0169, 0.1965, 0.7381, 0.9250],
[0.8292, 0.2519, 0.1531, 0.8987],
[0.1365, 0.4650, 0.4005, 0.7589]])
tensor([[False, False, True, True],
[ True, False, False, True],
[False, False, False, True]])
tensor([0.7381, 0.9250, 0.8292, 0.8987, 0.7589])
tensor([0.7381, 0.9250, 0.8292, 0.8987, 0.7589])
6. take(inputTensor, indexTensor)
根据一维的索引张量 indexTensor,返回一个新的一维 tensor,inputTensor 看成是一维的。
a = src = torch.tensor([[4, 3, 5],
[6, 7, 8]])
print(a.size())
b = torch.tensor([0, 2, 5]) # 如 0 --> 4, 2 --> 5, 5 --> 8
c = torch.take(a, b)
print(c)
print(c.size())
# 输出结果
torch.Size([2, 3])
tensor([4, 5, 8])
torch.Size([3])
总结:涉及到索引,就会存在索引越界的常见问题(如下所示),在使用的时候要注意一下。
IndexError: index 29 is out of bounds for dimension 1 with size 28
有不足之处,欢迎一起交流学习!