import torch
# torch.where(cond, x, y): element-wise select -> x where cond is True, else y.
a = torch.rand(4, 4)
b = torch.rand(4, 4)
out = torch.where(a > 0.5, a, b)
print(a)
print(b)
print(out)
# Example output from one run (torch.rand is unseeded, so values differ each run):
# tensor([[0.9875, 0.0399, 0.4628, 0.0218],
#         [0.6645, 0.6531, 0.6494, 0.4461],
#         [0.0189, 0.6904, 0.1290, 0.1731],
#         [0.0277, 0.1196, 0.8279, 0.6042]])
# tensor([[0.0552, 0.7499, 0.3638, 0.6014],
#         [0.0851, 0.3886, 0.7147, 0.1416],
#         [0.6719, 0.7295, 0.0700, 0.0164],
#         [0.7201, 0.5614, 0.7305, 0.7452]])
# tensor([[0.9875, 0.7499, 0.3638, 0.6014],
#         [0.6645, 0.6531, 0.6494, 0.1416],
#         [0.6719, 0.6904, 0.0700, 0.0164],
#         [0.7201, 0.5614, 0.8279, 0.6042]])
# torch.index_select(input, dim, index): pick whole slices along `dim`,
# in the order listed by `index` (dim=0 -> rows, dim=1 -> columns).
a = torch.rand(4, 4)
print(a)
idx = torch.tensor([0, 3, 2])
rows = torch.index_select(a, dim=0, index=idx)
cols = torch.index_select(a, dim=1, index=idx)
print(rows)
print(cols)
# Example output from one run (torch.rand is unseeded, so values differ each run):
# tensor([[0.2547, 0.1544, 0.5950, 0.1760],
#         [0.3791, 0.4782, 0.5972, 0.6868],
#         [0.4553, 0.2166, 0.5993, 0.2549],
#         [0.4715, 0.8413, 0.7631, 0.6736]])
# tensor([[0.2547, 0.1544, 0.5950, 0.1760],
#         [0.4715, 0.8413, 0.7631, 0.6736],
#         [0.4553, 0.2166, 0.5993, 0.2549]])
# tensor([[0.2547, 0.1760, 0.5950],
#         [0.3791, 0.6868, 0.5972],
#         [0.4553, 0.2549, 0.5993],
#         [0.4715, 0.6736, 0.7631]])
# torch.gather(input, dim, index): read `input` along `dim` at the positions
# given by `index`; the result has the same shape as `index`.
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
out = torch.gather(
    a,
    dim=0,
    index=torch.tensor([[0, 1, 1, 1],
                        [0, 1, 2, 2],
                        [0, 1, 3, 3]]),
)
print(out)
print(out.shape)
# Indexing rule, written for a 3-D input:
#   dim = 0: out[i, j, k] = input[index[i, j, k], j, k]
#   dim = 1: out[i, j, k] = input[i, index[i, j, k], k]
#   dim = 2: out[i, j, k] = input[i, j, index[i, j, k]]
# Example output:
# tensor([[ 1.,  2.,  3.,  4.],
#         [ 5.,  6.,  7.,  8.],
#         [ 9., 10., 11., 12.],
#         [13., 14., 15., 16.]])
# tensor([[ 1.,  6.,  7.,  8.],
#         [ 1.,  6., 11., 12.],
#         [ 1.,  6., 15., 16.]])
# torch.Size([3, 4])
# torch.masked_select(input, mask): 1-D tensor of the elements where mask is True.
a = torch.linspace(1, 16, 16).view(4, 4)
mask = torch.gt(a, 8)  # element-wise a > 8
print(a)
print(mask)
out = torch.masked_select(a, mask)
print(out)
# Example output:
# tensor([[ 1.,  2.,  3.,  4.],
#         [ 5.,  6.,  7.,  8.],
#         [ 9., 10., 11., 12.],
#         [13., 14., 15., 16.]])
# tensor([[False, False, False, False],
#         [False, False, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])
# tensor([ 9., 10., 11., 12., 13., 14., 15., 16.])
# torch.take(input, index): index into `input` as if it were flattened to 1-D
# (row-major), so index 15 is the last element of the 4x4 tensor.
a = torch.linspace(1, 16, 16).view(4, 4)
b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))
print(b)
# Example output:
# tensor([ 1., 16., 14., 11.])
# torch.nonzero(input): coordinates of the nonzero elements, one row per element.
a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
out = torch.nonzero(a)
print(a)
print(out)
# Sparse (coordinate-list) representation: each output row is the
# (row, col) index of one nonzero entry of `a`.
# Example output:
# tensor([[0, 1, 2, 0],
#         [2, 3, 0, 1]])
# tensor([[0, 1],
#         [0, 2],
#         [1, 0],
#         [1, 1],
#         [1, 3]])