where(来源函数)
torch.where(condition,x,y)->tensor
当满足condition,则来自于a,反之来自b
import torch
cond=torch.tensor([[0.6769,0.7271],[0.8884,0.4163]])
a=torch.tensor([[0,0],[0,0]])
b=torch.tensor([[1,1],[1,1]])
torch.where(cond>0.5,a,b)
得到结果
tensor([[0, 0],
[0, 1]])
输出为0的代表来源为a,输出为1的代表来源为b
gather(查表的过程)
torch.gather(input,dim,index,out=None)->tensor
例
就像是给了数据以后,查表得到对应参数,再收集回来进行输出。
gather函数即为gather(对应的参数表,dim,数据表)
import torch
prob=torch.randn(4,10)
idx=prob.topk(dim=1,k=3)
idx=idx[1]
label=torch.arange(10)+100
torch.gather(label.expand(4,10),dim=1,index=idx.long())
每一次函数对应参数为
输出结果:
tensor([[101, 105, 107],
[109, 102, 106],
[103, 105, 108],
[102, 104, 100]])