pytorch中index_select(),masked_select(),gather()使用

index_select()

index_select(input, dim, index)
在指定维度dim上选取,比如选取某些行、某些列

input中输入的一个该是一个张量(tensor)
dim代表选取的维度:0代表行,1代表列
后面的张量代表的是指定的行或列

input.index_select(dim, index)此写法与index_select(input, dim, index)效果相同

x = torch.tensor([[2,1,3,0],[1,2,0,4],[4,3,2,1]])
print(x)

b = torch.index_select(x,0,torch.tensor([0,1]))
print(b)

print(x.index_select(0,torch.tensor([0,1])))

c = torch.index_select(x,1,torch.tensor([0,1]))
print(c)

输出如下:

tensor([[2, 1, 3, 0],
        [1, 2, 0, 4],
        [4, 3, 2, 1]])

tensor([[2, 1, 3, 0],
        [1, 2, 0, 4]])

tensor([[2, 1, 3, 0],
        [1, 2, 0, 4]])
        
tensor([[2, 1],
        [1, 2],
        [4, 3]])

masked_select()

masked_select(input, mask)
mask就相当于检索条件应该是一个值为True或False的张量(tensor)
如下:

#定义一个3*4的随机张量
x = torch.randn(3,4)
print(x)

#定义mask大于0的为True否则为False
mask = x.ge(0)
print(mask)

print(torch.masked_select(x,mask)) 

输出:

tensor([[ 0.2216, -1.0851, -1.8103,  1.5319],
        [ 0.5802,  0.4164, -1.4859, -2.3838],
        [-0.1804,  0.0458, -1.2227, -0.7634]])
        
tensor([[ True, False, False,  True],
        [ True,  True, False, False],
        [False,  True, False, False]])
        
tensor([0.2216, 1.5319, 0.5802, 0.4164, 0.0458])

gather()

gather(input, dim, index)
根据index,在dim维度上选取数据,输出的size与index一样

index维度与input的维度是一致的
dim = 0按页检索
dim = 1按行检索
dim = 2按列检索

dim = 1
如下:

a = torch.randint(0, 30, (2, 3, 5))
print(a)

index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                          
                         [[1,2,2,2,2],
                          [0,0,0,0,0],
                          [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)

当dim = 1时 index张量里的数值第一行[0,1,2,0,2]就代表在张量a 上半矩阵中第零列取第0个元素,第一列取第1个元素,第二列取第2个元素,第三列取第0个元素,第四列取第2个元素…接下来的以此类推就能得到。
输出:

tensor([[[17,  4,  3, 22, 20],
         [11,  8, 10,  7, 11],
         [ 4,  4, 17, 23,  8]],

        [[ 2, 19,  0, 12, 28],
         [28, 11,  7, 26, 16],
         [22, 12, 19, 13,  9]]])
True
tensor([[[17,  8, 17, 22,  8],
         [17,  4,  3, 22, 20],
         [11,  8, 10,  7, 11]],

        [[28, 12, 19, 13,  9],
         [ 2, 19,  0, 12, 28],
         [22, 12, 19, 13,  9]]])

dim = 2
如下:

a = torch.randint(0, 30, (2, 3, 5))
print(a)

index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                          
                         [[1,2,2,2,2],
                          [0,0,0,0,0],
                          [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 2,index)
print(b)

当dim = 2时 index张量里的数值第一行[0,1,2,0,2]就代表在张量a 上半矩阵中第零行取第0个元素,第零行取第1个元素,第零行取第2个元素,第零行取第0个元素,第零行取第2个元素…接下来的以此类推就能得到。
输出:

tensor([[[13, 24, 18, 23, 14],
         [ 7, 20, 10, 24, 11],
         [10, 19, 19,  8,  5]],

        [[ 5, 26, 15, 27,  4],
         [29,  6,  2,  4, 25],
         [12, 10, 28, 29, 24]]])
True
tensor([[[13, 24, 18, 13, 18],
         [ 7,  7,  7,  7,  7],
         [19, 19, 19, 19, 19]],

        [[26, 15, 15, 15, 15],
         [29, 29, 29, 29, 29],
         [28, 28, 28, 28, 28]]])

dim = 0
如下:

a = torch.randint(0, 30, (2, 3, 5))
print(a)

index = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                          
                         [[1,0,0,0,0],
                          [0,0,0,0,0],
                          [1,1,0,0,0]]])
print(a.size()==index.size())
b = torch.gather(a, 0,index)
print(b)

当dim = 0时 index中的0和1代表的是在该位置取第0页还是第1页中该位置的元素。
输出:

tensor([[[14, 29,  5, 29,  3],
         [27, 15, 19, 27,  9],
         [27,  4,  9,  5,  6]],

        [[29, 27, 14,  2,  6],
         [14,  6, 11,  7, 22],
         [18,  6, 14,  1, 18]]])
True
tensor([[[14, 27, 14, 29,  6],
         [27,  6, 11,  7, 22],
         [18,  6, 14,  1, 18]],

        [[29, 29,  5, 29,  3],
         [27, 15, 19, 27,  9],
         [18,  6,  9,  5,  6]]])
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值