index_select()
index_select(input, dim, index)
在指定维度dim上选取,比如选取某些行、某些列
input中输入的一个该是一个张量(tensor)
dim代表选取的维度:0代表行,1代表列
后面的张量代表的是指定的行或列
input.index_select(dim, index)此写法与index_select(input, dim, index)效果相同
x = torch.tensor([[2,1,3,0],[1,2,0,4],[4,3,2,1]])
print(x)
b = torch.index_select(x,0,torch.tensor([0,1]))
print(b)
print(x.index_select(0,torch.tensor([0,1])))
c = torch.index_select(x,1,torch.tensor([0,1]))
print(c)
输出如下:
tensor([[2, 1, 3, 0],
[1, 2, 0, 4],
[4, 3, 2, 1]])
tensor([[2, 1, 3, 0],
[1, 2, 0, 4]])
tensor([[2, 1, 3, 0],
[1, 2, 0, 4]])
tensor([[2, 1],
[1, 2],
[4, 3]])
masked_select()
masked_select(input, mask)
mask就相当于检索条件应该是一个值为True或False的张量(tensor)
如下:
#定义一个3*4的随机张量
x = torch.randn(3,4)
print(x)
#定义mask大于0的为True否则为False
mask = x.ge(0)
print(mask)
print(torch.masked_select(x,mask))
输出:
tensor([[ 0.2216, -1.0851, -1.8103, 1.5319],
[ 0.5802, 0.4164, -1.4859, -2.3838],
[-0.1804, 0.0458, -1.2227, -0.7634]])
tensor([[ True, False, False, True],
[ True, True, False, False],
[False, True, False, False]])
tensor([0.2216, 1.5319, 0.5802, 0.4164, 0.0458])
gather()
gather(input, dim, index)
根据index,在dim维度上选取数据,输出的size与index一样
index维度与input的维度是一致的
dim = 0按页检索
dim = 1按行检索
dim = 2按列检索
dim = 1
如下:
a = torch.randint(0, 30, (2, 3, 5))
print(a)
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
当dim = 1时 index张量里的数值第一行[0,1,2,0,2]就代表在张量a 上半矩阵中第零列取第0个元素,第一列取第1个元素,第二列取第2个元素,第三列取第0个元素,第四列取第2个元素…接下来的以此类推就能得到。
输出:
tensor([[[17, 4, 3, 22, 20],
[11, 8, 10, 7, 11],
[ 4, 4, 17, 23, 8]],
[[ 2, 19, 0, 12, 28],
[28, 11, 7, 26, 16],
[22, 12, 19, 13, 9]]])
True
tensor([[[17, 8, 17, 22, 8],
[17, 4, 3, 22, 20],
[11, 8, 10, 7, 11]],
[[28, 12, 19, 13, 9],
[ 2, 19, 0, 12, 28],
[22, 12, 19, 13, 9]]])
dim = 2
如下:
a = torch.randint(0, 30, (2, 3, 5))
print(a)
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 2,index)
print(b)
当dim = 2时 index张量里的数值第一行[0,1,2,0,2]就代表在张量a 上半矩阵中第零行取第0个元素,第零行取第1个元素,第零行取第2个元素,第零行取第0个元素,第零行取第2个元素…接下来的以此类推就能得到。
输出:
tensor([[[13, 24, 18, 23, 14],
[ 7, 20, 10, 24, 11],
[10, 19, 19, 8, 5]],
[[ 5, 26, 15, 27, 4],
[29, 6, 2, 4, 25],
[12, 10, 28, 29, 24]]])
True
tensor([[[13, 24, 18, 13, 18],
[ 7, 7, 7, 7, 7],
[19, 19, 19, 19, 19]],
[[26, 15, 15, 15, 15],
[29, 29, 29, 29, 29],
[28, 28, 28, 28, 28]]])
dim = 0
如下:
a = torch.randint(0, 30, (2, 3, 5))
print(a)
index = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
print(a.size()==index.size())
b = torch.gather(a, 0,index)
print(b)
当dim = 0时 index中的0和1代表的是在该位置取第0页还是第1页中该位置的元素。
输出:
tensor([[[14, 29, 5, 29, 3],
[27, 15, 19, 27, 9],
[27, 4, 9, 5, 6]],
[[29, 27, 14, 2, 6],
[14, 6, 11, 7, 22],
[18, 6, 14, 1, 18]]])
True
tensor([[[14, 27, 14, 29, 6],
[27, 6, 11, 7, 22],
[18, 6, 14, 1, 18]],
[[29, 29, 5, 29, 3],
[27, 15, 19, 27, 9],
[18, 6, 9, 5, 6]]])