高阶操作
where
torch.where(condition,x,y)
(条件,A矩阵,B矩阵),符合条件从A对应位置选数,不符合就从B对应位置选数
条件选择函数
a=torch.full([2,2],0)
b=torch.full([2,2],1)
a,b
'''
(tensor([[0, 0],
[0, 0]]),
tensor([[1, 1],
[1, 1]]))
'''
#条件
cond=torch.rand(2,2)
cond
'''
tensor([[0.6961, 0.8969],
[0.2795, 0.9759]])
'''
torch.where(cond>0.5,a,b)
#如果cond的值>0.5就选取a对应位置的数,不是就选取b中对应位置的数
'''
tensor([[0, 0],
[1, 0]])
'''
gather
torch.gather(input,dim,index,out=None)-Tensor
input,输入数据
dim,查看维度
index,查看索引
out=None
prob=torch.randn(4,10)
'''
tensor([[ 0.6805, -0.4651, 0.6448, 0.6679, -0.5646, 2.3565, 0.9479, -0.0406,
-0.4645, 1.3624],
[ 0.8647, -0.5109, 0.5100, 0.6534, -0.8373, -1.8661, -0.8300, -0.0230,
-0.2076, 0.6472],
[ 0.9843, 1.0484, 0.1264, -1.2768, 0.7247, 0.9827, 1.1230, 0.9566,
0.4962, -0.9180],
[ 1.3375, 0.7297, -0.8324, 0.5294, -1.7625, 0.7328, 0.9702, -0.0741,
2.6688, 0.1584]])
'''
#得到按第一维排序的top3的数的大小以及位置,输出形式与原来的数一样
idx=prob.topk(dim=1,k=3)
idx
'''
torch.return_types.topk(
values=tensor([[2.3565, 1.3624, 0.9479],
[0.8647, 0.6534, 0.6472],
[1.1230, 1.0484, 0.9843],
[2.6688, 1.3375, 0.9702]]),
indices=tensor([[5, 9, 6],
[0, 3, 9],
[6, 1, 0],
[8, 0, 6]]))
'''
label=torch.arange(10)+100
label
'''
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
'''
idx=idx[1]#只得到第二部分即对应数的位置
idx
'''
tensor([[5, 9, 6],
[0, 3, 9],
[6, 1, 0],
[8, 0, 6]])
'''
idx.long()
'''tensor([[5, 9, 6],
[0, 3, 9],
[6, 1, 0],
[8, 0, 6]])
'''
#按idx的下标在label.expand中查找对应的数
torch.gather(label.expand(4,10),dim=1,index=idx.long())
'''
tensor([[105, 109, 106],
[100, 103, 109],
[106, 101, 100],
[108, 100, 106]])
'''