PyTorch 中的【高级索引】 或 【花式索引】

在 PyTorch 中,“高级索引” 或 “花式索引” 允许用户使用整数数组或布尔数组进行索引,从而获取张量中的指定元素。

具体规则:

  1. 整数数组索引:使用一个整数数组作为索引,可以获取张量中指定位置的元素。

  2. 布尔数组索引:使用一个布尔数组作为索引,可以根据布尔数组的 True/False 值获取张量中对应位置的元素。

举例说明:

整数数组索引:
import torch

# 创建一个张量
tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])

# 使用整数数组索引获取指定位置的元素
indices = torch.tensor([0, 2])  # 指定要获取的行索引
result = tensor[indices]  # 获取指定行的元素
print(result)

输出:

tensor([[1, 2],
        [5, 6]])
布尔数组索引:
import torch

# 创建一个张量
tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])

# 创建一个布尔数组,用于选择元素
mask = torch.tensor([True, False, True])  # 选择第1和第3行的元素
result = tensor[mask]  # 根据布尔数组选择元素
print(result)

输出:

tensor([[1, 2],
        [5, 6]])

这些是使用 PyTorch 中的高级索引进行元素选择的基本规则和示例。

多维数组索引1
import torch

# 创建一个示例张量
tensor = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]]
])

# 创建两个索引数组,用于选择元素
indices1 = torch.tensor([[0, 1], [2, 0]])  # 选择第0个维度(行)的索引
indices2 = torch.tensor([[1, 2], [0, 2]])  # 选择第1个维度(列)的索引

# 使用多维数据索引选择元素
result = tensor[indices1, indices2]  # 根据索引数组选择元素
print(result)

输出:

tensor([[[ 4,  5,  6],
         [ 7,  8,  9]],

        [[13, 14, 15],
         [ 4,  5,  6]]])
多维数组索引2
lq = 4
lk = 5
k = torch.rand((3,lq,lk))
index1 = torch.randint(lq,(1,2*lk))
index2 = torch.randint(lk,(lq,2*lk))
print(index1.shape,index2.shape)
print(k.shape,k[:, index1, index2].shape)
print('index1:\n',index1)
print('index2:\n',index2)
print('k:\n',k)
print('index k:\n',k[:,index1,index2])

其中,index1和index2可以互相broadcasting,并且index1和index2broadcasting的shape决定了索引结果的shape。

输出为:

torch.Size([1, 10]) torch.Size([4, 10])
torch.Size([3, 4, 5]) torch.Size([3, 4, 10])
index1:
 tensor([[0, 1, 2, 3, 3, 0, 3, 0, 2, 3]])
index2:
 tensor([[1, 1, 0, 0, 0, 3, 4, 1, 4, 4],
        [3, 2, 2, 2, 1, 3, 4, 1, 4, 1],
        [1, 1, 3, 0, 2, 2, 0, 0, 1, 4],
        [4, 3, 3, 0, 4, 0, 3, 3, 3, 2]])
k:
 tensor([[[0.8977, 0.2755, 0.6238, 0.4197, 0.9837],
         [0.1231, 0.4493, 0.2263, 0.9272, 0.6347],
         [0.3343, 0.3117, 0.6854, 0.2295, 0.6499],
         [0.8584, 0.3650, 0.2476, 0.6275, 0.4702]],

        [[0.8483, 0.6208, 0.6188, 0.4867, 0.1121],
         [0.7733, 0.3900, 0.5515, 0.8151, 0.0637],
         [0.3329, 0.7633, 0.1499, 0.2026, 0.0895],
         [0.0793, 0.1707, 0.5915, 0.1170, 0.2679]],

        [[0.4250, 0.5561, 0.2284, 0.8940, 0.1764],
         [0.8897, 0.2199, 0.1317, 0.6584, 0.7289],
         [0.3934, 0.3325, 0.7833, 0.7059, 0.7230],
         [0.4195, 0.0095, 0.9322, 0.5098, 0.5191]]])
index k:
 tensor([[[0.2755, 0.4493, 0.3343, 0.8584, 0.8584, 0.4197, 0.4702, 0.2755,
          0.6499, 0.4702],
         [0.4197, 0.2263, 0.6854, 0.2476, 0.3650, 0.4197, 0.4702, 0.2755,
          0.6499, 0.3650],
         [0.2755, 0.4493, 0.2295, 0.8584, 0.2476, 0.6238, 0.8584, 0.8977,
          0.3117, 0.4702],
         [0.9837, 0.9272, 0.2295, 0.8584, 0.4702, 0.8977, 0.6275, 0.4197,
          0.2295, 0.2476]],

        [[0.6208, 0.3900, 0.3329, 0.0793, 0.0793, 0.4867, 0.2679, 0.6208,
          0.0895, 0.2679],
         [0.4867, 0.5515, 0.1499, 0.5915, 0.1707, 0.4867, 0.2679, 0.6208,
          0.0895, 0.1707],
         [0.6208, 0.3900, 0.2026, 0.0793, 0.5915, 0.6188, 0.0793, 0.8483,
          0.7633, 0.2679],
         [0.1121, 0.8151, 0.2026, 0.0793, 0.2679, 0.8483, 0.1170, 0.4867,
          0.2026, 0.5915]],

        [[0.5561, 0.2199, 0.3934, 0.4195, 0.4195, 0.8940, 0.5191, 0.5561,
          0.7230, 0.5191],
         [0.8940, 0.1317, 0.7833, 0.9322, 0.0095, 0.8940, 0.5191, 0.5561,
          0.7230, 0.0095],
         [0.5561, 0.2199, 0.7059, 0.4195, 0.9322, 0.2284, 0.4195, 0.4250,
          0.3325, 0.5191],
         [0.1764, 0.6584, 0.7059, 0.4195, 0.5191, 0.4250, 0.5098, 0.8940,
          0.7059, 0.9322]]])

再比如这个例子:

Q = torch.rand((B,H,L_Q,D))
M_top = torch.randint(L_Q,(B,H,sample_Q))
Q_reduce = Q[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   M_top, :]  # (B,H,sample_Q)
Q_reduce = Q[torch.arrange(B).unsqueeze(-1).unsqueeze(-1),torch.arrange(H).unsqueeze(0).unsqueeze(-1),M_top,:]  # (B,H,sample_Q)

两种Q_reduce的写法是一致的。

  • 7
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Kiki酱。

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值