torch.index_select 是 PyTorch 中用于从张量特定维度上按索引提取元素的函数。以下从功能、语法、应用案例、常见错误四个方面详细解析:

一、功能与语法
核心功能
从张量的指定维度(轴)上,根据给定的索引列表提取元素,生成新的张量。不改变原张量,返回新分配的内存。
语法
torch.index_select(input, dim, index, out=None) → Tensor
参数
input:输入张量,必须是 PyTorch 张量。dim:整数,表示要索引的维度(轴)。例如:dim=0表示按行索引(对二维张量)。dim=1表示按列索引(对二维张量)。
index:一维张量(必须是LongTensor类型),包含要提取的索引值。out(可选):输出张量,用于存储结果。
二、6个实际应用案例
案例1:从矩阵中选择特定行
import torch
# 创建3×3矩阵
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 选择第0行和第2行
indices = torch.tensor([0, 2]) # 必须是LongTensor
result = torch.index_select(x, dim=0, index=indices)
print(result)
# 输出:
# tensor([[1, 2, 3],
# [7, 8, 9]])
案例2:从批量数据中提取特定样本
在深度学习中,从批量数据(batch_size, channels, height, width)中选择特定样本:
# 假设batch_size=4,每个样本是3通道的28×28图像

最低0.47元/天 解锁文章
2285

被折叠的 条评论
为什么被折叠?



