版本: pytorch1.0
目的: tensor中的index与mask
例子
下面针对torch.long
与torch.uint8
数据类型在index/mask 中的不同作用进行分析
t = torch.rand(4,2)
"""
tensor([[0.5492, 0.2083],
[0.3635, 0.5198],
[0.8294, 0.9869],
[0.2987, 0.0279]])
"""
# 注意数据类型是 uint8,
mask= torch.ones(4,dtype=torch.uint8)
mask[2] = 0
print(mask)
print(t[mask, :])
"""
tensor([1, 1, 0, 1], dtype=torch.uint8)
tensor([[0.5492, 0.2083],
[0.3635, 0.5198],
[0.2987, 0.0279]])
"""
# 注意, 数据类型是long
idx = torch.ones(3,dtype=torch.long)
idx[1] = 0
print(idx)
print(t[idx, :])
"""
tensor([1, 0, 1])
tensor([[0.3635, 0.5198],
[0.5492, 0.2083],
[0.3635, 0.5198]])
"""
结论
- 当
mask
的数据类型是torch.uint8
时,此时的tensor用作mask,tensor中的1对应的行/列保留,0对应的行/列舍去。且被mask的维度必须与原始tensor的维度一致。其实很好理解,因为你是一个mask,是要覆盖在原始tensor上面的,因此需要你和原始tensor保持一致的dimension。上面的例子中,需要保证mask.size(0)
==t.shape(0)
,否则会报错。 - 当
idx
的数据类型是torch.long
时,此时的tensor用作index,tensor中的每个数字代表着将要取出的tensor的行列索引。用作index时是为了从原始的tensor中取出指定的行列,因此,取出多少不受限。就上面的例子而言不需要保证idx.size(0)
==t.shape(0)
。