在 PyTorch 中,类似 input[range(target.shape[0]), target]
这样的表达式通常用于获取输入张量(input
)中特定位置的值,其中 位置由 target
张量指定的。 首先,range(target.shape[0])
它创建了一个从 0
到 target
张量中第一个维度的大小计算得出的整数范围。例如,如果 target
张量的形状为 (5, 3)
,则 range(target.shape[0])
返回一个大小为 5
的整数范围。然后,这个整数范围用于指定要获取的输入中的特定行, 即 input[range(target.shape[0])]
会返回一个张量(input
)与 target
张量第一维大小一样, 且包含输入张量(input
)中所指定行所有的列。 接下来,使用 target
张量指定列,从而选取每行中的所需位置。input[range(target.shape[0]), target]
最终会返回一个大小为 target.shape
的张量,包含了输入张量中所有位置需要的值。
举例来说, 如果 input
张量的形状为 (5, 4)
, target
张量为 (5)
, 并且具有以下值:
input = Tensor([[ 764, 4, 67, 785],
[ 311, 101, 911, 199],
[ 759, 362, 215, 651],
[ 471, 821, 738, 875],
[ 109, 828, 994, 675]])
target = Tensor([1, 0, 2, 3, 2])
那么 input[range(target.shape[0]), target]
将返回具有以下值得张量:
Tensor([ 4, 311, 215, 875, 994])
其中每个值都是得到所需位置的结果。