pytorc torch.uint8与torch.long/ torch. float

版本: pytorch1.0
目的: tensor中的index与mask


例子

下面针对torch.longtorch.uint8数据类型在index/mask 中的不同作用进行分析

t = torch.rand(42)
"""
tensor([[0.5492, 0.2083],
        [0.3635, 0.5198],
        [0.8294, 0.9869],
        [0.2987, 0.0279]])
"""
# 注意数据类型是 uint8, 
mask= torch.ones(4,dtype=torch.uint8)
mask[2] = 0
print(mask)
print(t[mask, :])
"""
 tensor([1, 1, 0, 1], dtype=torch.uint8)
 
 tensor([[0.5492, 0.2083],
        [0.3635, 0.5198],
        [0.2987, 0.0279]]) 
"""

# 注意, 数据类型是long
idx = torch.ones(3,dtype=torch.long)
idx[1] = 0
print(idx)
print(t[idx, :])
"""
tensor([1, 0, 1])
tensor([[0.3635, 0.5198],
        [0.5492, 0.2083],
        [0.3635, 0.5198]])
"""

结论

  • mask的数据类型是torch.uint8时,此时的tensor用作mask,tensor中的1对应的行/列保留,0对应的行/列舍去。且被mask的维度必须与原始tensor的维度一致。其实很好理解,因为你是一个mask,是要覆盖在原始tensor上面的,因此需要你和原始tensor保持一致的dimension。上面的例子中,需要保证 mask.size(0) == t.shape(0),否则会报错。
  • idx的数据类型是torch.long时,此时的tensor用作index,tensor中的每个数字代表着将要取出的tensor的行列索引。用作index时是为了从原始的tensor中取出指定的行列,因此,取出多少不受限。就上面的例子而言不需要保证 idx.size(0) == t.shape(0)
  • 9
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值