torch.where(a , b , c)
三个参数意义如下:
a:判断条件
b:若满足条件,则取x中元素
c:若不满足条件,则取y中元素
需要注意的是a,b,c三者shape相同。
代码示例如下:
import torch
a = torch.randint(1,9,[3,4])
a = a.float()
print(a.shape)
print(a)
zeros = torch.zeros_like(a[0,:])
ones = torch.ones_like(a[0,:])
output = a.clone()
for n in range(a.shape[0]):
output[n, :] = torch.where(a[n, :] >= torch.mean(a[n, :]), ones, zeros)
print(output.shape)
print(output)
输出结果如图:
可以看出,选出了每一维度比均值大的值的位置。可以与torch.index_select搭配使用。
可能的报错:
RuntimeError: expand(torch.cuda.FloatTensor{[12, 16]}, size=[16]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)
原因:三个参数shape不相同