where
torch.where(condition,x,y) #condition必须是tensor类型
condition的维度和x,y一致,用1和0分别表示该位置的取值
例:输入:
cond = torch.tensor([[0.6, 0.7],
[0.3, 0.6]])
a = torch.tensor([[1., 1.],
[1., 1.]])
b = torch.tensor([[0., 0.],
[0., 0.]])
c = torch.where(cond > 0.5, a, b) #此时cond只有0和1的值
print(c)
输出:
tensor([[1., 1.],
[0., 1.]])
高度并行
gather
torch.gather(input, dim, index, out=None)
相当于查表操作
举例:
prob = torch.randn(4, 10)
idx = prob.topk(dim=1, k=3) # prob在维度1中前三个最大的数,一共有4行,返回值和对应的下标
print("all of topk idx: ", idx)
idx