高级索引(Advanced Indexing)是 NumPy 和 PyTorch 中提供的一种强大且灵活的索引机制,它允许使用多个数组作为索引来访问张量或数组的特定元素。高级索引使得在多维数组中选择和操作特定元素变得更加直观和简便。
基本概念
- 基本索引(Basic Indexing):
- 使用整数、切片、布尔数组等进行索引。
- 索引操作与 Python 的列表索引非常相似。
- 高级索引(Advanced Indexing):
- 使用整数数组或张量进行索引。
- 返回的是一个新数组或张量,维度可能会根据索引的使用方式而改变。
示例
示例 1:使用整数数组进行索引
import numpy as np
# 创建一个 3x4 的数组
x = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 使用高级索引获取特定元素
rows = np.array([0, 1, 2])
cols = np.array([1, 2, 3])
print(x[rows, cols]) # 输出 [2 7 12]
在这个示例中:
x
是一个形状为(3, 4)
的二维数组。rows = np.array([0, 1, 2])
指定要从每一行中选择的行索引。cols = np.array([1, 2, 3])
指定要从每一行中选择的列索引。x[rows, cols]
使用高级索引返回的是x[0, 1]
,x[1, 2]
和x[2, 3]
对应的元素,结果是[2, 7, 12]
。
示例 2:使用 PyTorch 进行高级索引
import torch
# 创建一个 2x3 的张量
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.tensor([0, 2])
# 使用高级索引获取特定元素
selected_values = y_hat[[0, 1], y]
print(selected_values) # 输出 tensor([0.1000, 0.5000])
在这个示例中:
y_hat
是一个形状为(2, 3)
的二维张量。y = torch.tensor([0, 2])
指定要选择的列索引。y_hat[[0, 1], y]
使用高级索引返回的是y_hat[0, 0]
和y_hat[1, 2]
对应的元素,结果是[0.1, 0.5]
。
高级索引的行为
-
维度变化:
- 高级索引返回的新数组或张量的形状与索引数组的形状有关。
- 如果索引数组的形状为
(n,)
,则返回数组的形状为(n,)
。 - 如果索引数组的形状为
(m, n)
,则返回数组的形状为(m, n)
。
-
广播机制:
- 高级索引可以结合广播机制来选择特定元素。
- 如果索引数组的维度不同,NumPy 或 PyTorch 会自动进行广播,使得索引操作可以在不同维度上进行。
示例 3:不同形状的索引数组
# 创建一个 3x4 的数组
x = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 使用形状为 (2,) 的索引数组
rows = np.array([0, 2])
cols = np.array([1, 3])
print(x[rows[:, np.newaxis], cols]) # 输出 [[ 2 4]
# [10 12]]
在这个示例中:
-
选择所有元素:
rows[:]
选择rows
的所有元素,结果仍然是[0, 2]
,形状为(2,)
。 -
增加新维度:
rows[:, np.newaxis]
在原始数组的第二个位置增加一个新维度。结果是一个二维数组,每个元素都变成一个单独的一行,形状为(2, 1)
。 rows[:, np.newaxis]
将rows
变为形状为(2, 1)
的数组。cols
的形状为(2,)
。- 通过广播,索引数组变为
(2, 2)
,分别对应x[0, 1]
,x[0, 3]
,x[2, 1]
和x[2, 3]
,结果是[[2, 4], [10, 12]]
。
总结
高级索引允许我们使用多个数组或张量来选择特定元素。它提供了比基本索引更灵活和强大的功能,特别适用于多维数组或张量中的复杂选择操作。高级索引在数据处理、机器学习和科学计算中非常有用。