介绍
torch.where
是 PyTorch 中的一个函数,用于根据给定条件从两个张量中选择元素并返回一个新的张量。它的使用方式如下:
torch.where(condition, x, y)
参数解释:
condition
: 一个布尔型的张量,用于指定选择元素的条件。当条件为 True 时,选择x
中对应位置的元素,否则选择y
中对应位置的元素。x
: 一个张量,表示条件为 True 时要选择的元素。y
: 一个张量,表示条件为 False 时要选择的元素。
返回值:
- 返回一个新的张量,根据条件选择
x
或y
中的元素。
示例代码
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
condition = torch.tensor([True, False, True])
result = torch.where(condition, x, y)
print(result)
输出结果:
tensor([1, 5, 3])
在上述示例中,根据 condition
中的值,选择 x
或 y
中对应位置的元素,最终返回一个新的张量。在这个例子中,结果为 [1, 5, 3]
,因为 condition
中的第一和第三个值为 True,所以选择 x
中的对应位置的值,即 1
和 3
,而 condition
中的第二个值为 False,所以选择 y
中的对应位置的值,即 5
。