pytorch 根据bool矩阵取出tensor中对应位置元素

1. bool矩阵当做索引(类型是:BoolTensor)

结果为一维向量(因为bool矩阵二维的,根据bool矩阵中True对应位置,把tensor数据中相应位置中的值取出来,组成一个新的一维tensor向量)

#布尔索引  用布尔索引总是会返回一份新创建的数据,原本的数据不会被改变。
a2 = np.arange(15).reshape(3,5)
print('a2===',a2)
mask = a2<5
b2 = a2[mask]
print('b2===',b2)
b2[0] = 17
print('a2===',a2) #修改b2中的数据,会发现原数据a2中的值没有发生改变。

输出结果:

a2=== tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
b2=== tensor([0, 1, 2, 3, 4])
a2=== tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

2. 一维bool向量当做索引(类型是:BoolTensor)

第一个例子:结果为一维向量(target是一维的)

target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)
masked_target = target[mask]
 
print(target)
print(mask)
print(masked_target)

输出结果如下:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False,  True, False, False,  True])
tensor([1., 2., 3.])

第二个例子:结果为一个二维的向量(x是二维的)

x = torch.randn(4,3)
y = torch.tensor([True,False,False,True])
print('x===',x)
print('y===',y)
print('x[y]===',x[y])
print('y.sum()===',y.sum())

输出结果:

x===tensor([[ 0.4563,  0.3963,  0.4101],
        [-0.4360, -0.2968,  1.0010],
        [ 0.2851,  0.0890,  0.5452],
        [ 0.8384, -1.1912,  0.2131]])
y===tensor([ True, False, False,  True])
x[y]===tensor([[ 0.4563,  0.3963,  0.4101],
        [ 0.8384, -1.1912,  0.2131]])
y.sum()===tensor(2)

3. torch.masked_select() 根据bool矩阵取出tensor中对应位置元素(类型是:BoolTensor)。

torch.masked_select(input, mask, *, out=None) → Tensor

此方法中mask是一个bool矩阵,在input中取出mask中True对应的值。
首先介绍参数:

  • input(tensor):需要进行处理的tensor。
  • mask(BoolTensor):包含了二进制掩码,要进行索引的tensor。
  • out:输出的结果tensor(结果为一维tensor)。

使用方法如下:
第一个例子:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
        [-1.2035,  1.2252,  0.5002,  0.6248],
        [ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
        [False, True, True, True],
        [False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

第二个例子:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.ge(0)
masked_target = torch.masked_select(target, mask)
 
print(target)
print(mask)
print(masked_target)

输出结果为:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False,  True, False, False,  True])
tensor([1., 2., 3.])

使用方法其实挺明显的,就是把input与mask相对应起来,取出mask中True所对应位置的数据,组成一维的tensor

注意:mask和input的形状可以不相同,但是它们必须是可以广播的。并且返回tensor和原tensor使用不同的内存,相互独立

4. 参考链接

  1. Python基础 切片索引、布尔索引、花式索引
  2. Pytorch中Tensor的索引,切片以及花式索引(fancy indexing)
  3. pytorch中bool类型的张量作为索引
  4. pytorch中几种tensor掩码的获取方法(含代码)
  5. pytorch每日一学35(torch.masked_select())根据bool矩阵取出tensor中对应位置元素
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值