之前使用的时候被网上一些错误的解释误导,通过自己验证.
1.概述
torch.where()函数的作用是按照一定的规则按照元素合并两个tensor类型。
torch.where(condition,a,b):
- 其中a,b均为tensor,且为同类型。
- 输入参数condition:限制条件,如果满足条件,则选择a,否则选择b作为输出。
2.例子
import torch
a=torch.tensor([[-0.3, 0.2, 5.0],
[-0.9, 1.0, 2.0]])
b=torch.tensor([[1.3, -0.2, 1.0],
[1.59, 3.0, 4.0]])
c = torch.where(a>0, a,b)
c
输出为:tensor([[1.3000, 0.2000, 5.0000], [1.5900, 1.0000, 2.0000]])