torch.where的用法为
torch.where(condition,x,y)
例子
A = torch.Tensor([-1,2,3])
B = torch.ones(3)
torch.where(A>1,A,B)
输出为:
Tensor([1,2,3])
torch.where的用法为
torch.where(condition,x,y)
例子
A = torch.Tensor([-1,2,3])
B = torch.ones(3)
torch.where(A>1,A,B)
输出为:
Tensor([1,2,3])