where
torch.where(condition, x y) --->Tensor
对条件判断,满足的选择填充对应X位置元素,否则Y位置
a = torch.rand(2,2)
b = torch.tensor([[0,0],
[0,0]])
c = torch.tensor([[1,1],[1,1]])
d = torch.where(a>0.5,b,c) #对条件判断,满足的获取b位置元素
print(a)
print(d)
tensor([[0.7284, 0.9584],
[0.0740, 0.3946]]) #a
tensor([[0, 0],
[1, 1]])
gather
torch.gather( input, dim , index, out = None) ---->Tensor
收集操作: 输入一个tensor,指定维度,索引,返回一个和索引shape一样的tensor
mn = torch.randn(4,10)
idx = mn.topk(k=3,dim=1)
idx = idx[1] #索引
print(mn),print(idx)
label = torch.arange(10)+100
p = torch.gather(label.expand(4,10),dim=1,index=idx.long()) #4行数从100-109
# print(idx.long()) #转换为longtensor数据格式 #按照索引取数据,
print(p)
tensor([[ 0.9868, 1.3073, -0.3827, -1.0585, 0.0805, -1.3429, 0.6678, -0.3388,
-0.4304, 0.2057],
[ 0.1931, -0.6879, -0.1194, 1.4844, -0.4510, -0.4621, -0.9452, 1.1003,
-1.0377, -1.0391],
[ 0.6447, -0.1251, -0.7113, 0.8599, -0.3897, -0.3618, -0.4018, 0.8156,
1.8524, -0.9277],
[-0.5331, 0.1754, 0.2532, 0.3705, 0.6344, 1.5306, -1.2829, -0.2778,
-0.4927, -0.3294]])
tensor([[1, 0, 6], #前三大的索引
[3, 7, 0],
[8, 3, 7],
[5, 4, 3]])
tensor([[101, 100, 106],
[103, 107, 100],
[108, 103, 107],
[105, 104, 103]])