1. torch.where
import torch
# torch.where(condition, x, y) selects element-wise: it takes the value from x
# wherever condition is True and the value from y everywhere else.
# Here the mask is cond > 0.5, which holds only in the first column
# (0.6 and 0.7), so c takes a's zeros there and b's ones in the second column.
cond = torch.tensor([[0.6, 0.4],
                     [0.7, 0.3]])
a = torch.zeros(2, 2)  # chosen where cond > 0.5
b = torch.ones(2, 2)   # chosen everywhere else
c = torch.where(cond.gt(0.5), a, b)
print(a, b, c, sep="\n")
结果:
tensor([[0., 0.],
        [0., 0.]])
tensor([[1., 1.],
        [1., 1.]])
tensor([[0., 1.],
        [0., 1.]])
2. torch.gather