# Example: PyTorch element-selection ops — where, index_select, gather, masked_select, take, nonzero.
import torch
# torch.where: element-wise pick from `a` where the condition holds, else from `b`.
# Tuple assignment evaluates left to right, so `a` is drawn from the RNG before `b`.
a, b = torch.rand(4, 4), torch.rand(4, 4)
print(a, b, sep='\n')
out = a.where(a > 0.5, b)  # same as torch.where(a > 0.5, a, b)
print(out)
print('*' * 100)
# index_select copies whole slices along `dim`; here rows 0 and 3 of `a`.
a = torch.rand(4, 4)
print(a)
rows = torch.tensor([0, 3])
out = a.index_select(0, rows)
print(out, out.shape, sep='\n')
print('*' * 100)
# 4x4 matrix holding 1..16 in row-major order.
a = torch.linspace(1, 16, steps=16).view(4, 4)
print(a)
# gather along dim 0: out[i][j] = a[index[i][j]][j]
index = torch.tensor([
    [0, 1, 1, 1],
    [0, 1, 2, 1],
])
out = a.gather(0, index)
print(out)
print('*' * 100)
a = torch.linspace(1, 16, steps=16).view(4, 4)
print(a)
# Boolean mask of elements greater than 8 (equivalent to torch.gt(a, 8)).
mask = a > 8
print(mask)
# masked_select returns the selected elements as a flat 1-D tensor.
out = a.masked_select(mask)
print(out)
print('*' * 100)
a = torch.linspace(1, 16, steps=16).view(4, 4)
print(a)
# take treats `a` as if flattened (row-major), so positions 0 and 2 are 1. and 3.
out = a.take(torch.tensor([0, 2]))
print(out)
print('*' * 100)
# torch.nonzero returns one row per non-zero element, holding that element's
# (row, col) coordinates; for this input: [[0, 1], [1, 0], [1, 1]].
a = torch.tensor([[0, 1], [3, 5]])
print(a)
# as_tuple=False is the default; passing it explicitly avoids the deprecation
# warning some PyTorch versions emit for the bare overload.
out = torch.nonzero(a, as_tuple=False)
print(out)
"""
Sample console output from one run (the torch.rand sections differ on every run):
tensor([[0.8969, 0.0789, 0.3952, 0.9113],
[0.7224, 0.0376, 0.0088, 0.6201],
[0.0165, 0.6198, 0.6357, 0.1959],
[0.2630, 0.7964, 0.9304, 0.4399]])
tensor([[0.9669, 0.6546, 0.5789, 0.4978],
[0.2343, 0.1951, 0.2673, 0.4190],
[0.6410, 0.3289, 0.6181, 0.8436],
[0.0032, 0.4834, 0.2422, 0.2723]])
tensor([[0.8969, 0.6546, 0.5789, 0.9113],
[0.7224, 0.1951, 0.2673, 0.6201],
[0.6410, 0.6198, 0.6357, 0.8436],
[0.0032, 0.7964, 0.9304, 0.2723]])
*****************************************************************************************
tensor([[0.2883, 0.7950, 0.1790, 0.0426],
[0.7390, 0.5961, 0.6238, 0.5012],
[0.3688, 0.5733, 0.4973, 0.2533],
[0.8975, 0.1249, 0.0760, 0.1148]])
tensor([[0.2883, 0.7950, 0.1790, 0.0426],
[0.8975, 0.1249, 0.0760, 0.1148]])
torch.Size([2, 4])
*****************************************************************************************
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([[ 1., 6., 7., 8.],
[ 1., 6., 11., 8.]])
*****************************************************************************************
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([[False, False, False, False],
[False, False, False, False],
[ True, True, True, True],
[ True, True, True, True]])
tensor([ 9., 10., 11., 12., 13., 14., 15., 16.])
*****************************************************************************************
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([1., 3.])
*****************************************************************************************
tensor([[0, 1],
[3, 5]])
tensor([[0, 1],
[1, 0],
[1, 1]])
Process finished with exit code
"""