Pytorch 基础之张量索引

文章详细介绍了Tensor张量在PyTorch中的索引与切片方法,包括index索引、selectfirst/lastN、selectbysteps、index_select、masked_select和take等操作,同时提到了索引越界的问题和使用时的注意事项。
摘要由CSDN通过智能技术生成

本次将介绍一下 Tensor 张量常用的索引与切片的方法:

1. index 索引

index 索引值表示相应维度值的对应索引

a = torch.rand(4, 3, 28, 28)
print(a[0].shape)             # 返回维度一的第 0 索引 tensor
print(a[0, 0].shape)          # 返回维度一 0 索引位置,维度二 0 索引位置的 tensor
print(a[0, 0, 0].shape)     # 返回维度一 0 索引,维度二 0 索引,维度三 0索引的 tensor
print(a[0, 0, 2, 4].shape)  # 返回维度一 0 索引,维度二 0 索引,维度三 2索引,维度四 4索引位置的 tensor (dim = 0)
print(a[0, 0, 2, 4])

# 输出结果
torch.Size([3, 28, 28])
torch.Size([28, 28])
torch.Size([28])
torch.Size([])
tensor(0.4504)
2. select first/last N

返回前 N 个或后 N 个的 tensor

【:】表示该维度所有值;

【:2】表示从索引 0 开始到索引 2 的值,包首不包尾

【1:】表示索引 1 开始到最后

【-2:】表示倒数第二个值到最后

【…】表示一个或几个维度不变

a = torch.rand(4, 3, 28, 28)
print(a[:2].shape)    # 返回维度一索引 0 ~ 2 的 tensor,相当于 a[:2, :, :, :].shape, : 表示都选择
print(a[:2, :1, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 0 ~ 1 的 tensor
print(a[:2, :1, :3, :4].shape) # 返回维度一索引 0 ~ 2,维度二索引 0 ~ 1,维度三索引 0 ~ 3,维度四索引 0 ~ 4 的 tensor
print(a[:2, 1:, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 1 ~ 3 的 tensor
print(a[:2, -2:, :, :].shape) # 返回维度一索引 0 ~ 2,维度二索引 1 ~ 3 的 tensor

# ---------【...】的应用 --------------
print(a[...].shape)       # 表示返回一样的 a
print(a[0, ...].shape)    # 表示返回维度一,索引 0 位置的 tensor
print(a[:, 1, ...].shape) # 表示返回维度二,索引 1 位置的 tensor
print(a[:, :, 2, ...].shape)   # 表示返回维度三,索引 2 位置的 tensor
print(a[..., 10].shape)    # 表示返回维度四,索引 10 位置的 tensor
print(a[..., :2].shape)    # 表示返回维度四,索引 0 ~2 数量的 tensor


# 输出结果
torch.Size([2, 3, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 3, 4])
torch.Size([2, 2, 28, 28])
torch.Size([2, 2, 28, 28])

# ---------【...】的应用的输出结果 --------------
torch.Size([4, 3, 28, 28])
torch.Size([3, 28, 28])
torch.Size([4, 28, 28])
torch.Size([4, 3, 28])
torch.Size([4, 3, 28])
torch.Size([4, 3, 28, 2])
3. select by steps

按一定的间隔 steps 返回 tensor

【0:28:2】表示从索引 0 开始到 28,间隔 2 取数,所以为 14

有二个冒号,便是按一定间隔取

a = torch.rand(4, 3, 28, 28)
print(a[:, :, 0:28:2, 0:28:4].shape)

#输出结果
torch.Size([4, 3, 14, 7])
4. index_select(intputTensor, dim, indexTensor)

根据输入的 inputTensor ,按指定的维度索引 dim,返回与 indexTensor 一样的 size,其它维度不变的新 tensor

a = torch.rand(4, 3, 28, 28)
b = a.index_select(2, torch.arange(8))    # 也可以 inputTensor 直接调用
c = torch.index_select(a, 2, torch.arange(8)) # 建议用这种形式,返回 a 第 3 个维度与 torch.arange(8)一样 size ,其它维度不变的新 tensor 
print(b.shape)
print(c.shape)

# 输出结果
torch.Size([4, 3, 8, 28])
torch.Size([4, 3, 8, 28])

5. masked_select(intputTensor, maskTensor)

返回一个满足 maskTensor 条件的一维 tensor

a = torch.rand(3, 4)
print(a)
x = a.ge(0.5)     # 大于 0.5 的 bool 张量
print(x)
print(a.masked_select(x))    # 返回值大于 0.5 的一维张量
print(torch.masked_select(a, x))    # 和上面一样,但建议用这种形式

# 输出结果
tensor([[0.0169, 0.1965, 0.7381, 0.9250],
        [0.8292, 0.2519, 0.1531, 0.8987],
        [0.1365, 0.4650, 0.4005, 0.7589]])
tensor([[False, False,  True,  True],
        [ True, False, False,  True],
        [False, False, False,  True]])
tensor([0.7381, 0.9250, 0.8292, 0.8987, 0.7589])
tensor([0.7381, 0.9250, 0.8292, 0.8987, 0.7589])

6. take(inputTensor, indexTensor)

根据一维的索引张量 indexTensor,返回一个新的一维 tensor,inputTensor 看成是一维的。

a = src = torch.tensor([[4, 3, 5],
                       [6, 7, 8]])
print(a.size())
b = torch.tensor([0, 2, 5])   # 如 0 --> 4, 2 --> 5, 5 --> 8
c = torch.take(a, b)
print(c)
print(c.size())

# 输出结果
torch.Size([2, 3])
tensor([4, 5, 8])
torch.Size([3])

总结:涉及到索引,就会存在索引越界的常见问题(如下所示),在使用的时候要注意一下。

IndexError: index 29 is out of bounds for dimension 1 with size 28

有不足之处,欢迎一起交流学习!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值