PyTorch 的input[range(target.shape[0]), target] 表达式

在PyTorch中,`input[range(target.shape[0]),target]`用于按目标张量`target`指定的位置选取输入张量`input`的值。此表达式首先创建一个整数范围,然后选取输入张量的特定行和列,返回一个与`target`形状相同的张量。例如,给定`input`和`target`张量,它会返回一个新的张量,其中包含`input`中对应位置的值。
摘要由CSDN通过智能技术生成

在 PyTorch 中,类似 input[range(target.shape[0]), target] 这样的表达式通常用于获取输入张量(input)中特定位置的值,其中 位置由 target 张量指定的。 首先,range(target.shape[0]) 它创建了一个从 0 到 target 张量中第一个维度的大小计算得出的整数范围。例如,如果 target 张量的形状为 (5, 3),则 range(target.shape[0]) 返回一个大小为 5 的整数范围。然后,这个整数范围用于指定要获取的输入中的特定行, 即 input[range(target.shape[0])] 会返回一个张量(input)与 target 张量第一维大小一样, 且包含输入张量(input)中所指定行所有的列。 接下来,使用 target 张量指定列,从而选取每行中的所需位置。input[range(target.shape[0]), target] 最终会返回一个大小为 target.shape 的张量,包含了输入张量中所有位置需要的值。

举例来说, 如果 input 张量的形状为 (5, 4), target 张量为 (5), 并且具有以下值:

input = Tensor([[ 764,    4,   67,  785],
                 [ 311,  101,  911,  199],
                 [ 759,  362,  215,  651],
                 [ 471,  821,  738,  875],
                 [ 109,  828,  994,  675]])
target = Tensor([1, 0, 2, 3, 2])

那么 input[range(target.shape[0]), target] 将返回具有以下值得张量:

Tensor([  4, 311, 215, 875, 994])

其中每个值都是得到所需位置的结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值