- Problem description:
>>> b.type()
'torch.FloatTensor'
>>> b.int()
tensor([[ 0,  1, -1,  0,  0],
        [ 0,  0,  1,  0,  2]], dtype=torch.int32)
>>> b
tensor([[-0.2353,  1.5432, -1.4633, -0.1470,  0.5092],
        [ 0.9389,  0.0459,  1.0843, -0.7417,  2.2712]])
>>> a
tensor([[-0.6341,  1.6152, -0.7700, -0.0680,  0.6076],
        [-0.0076,  0.6456,  1.0000, -0.0277, -0.3403]])
>>> b.masked_fill_(a.int(),2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected object of scalar type Byte but got scalar type Int for argument #2 'mask'
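Cause: the mask passed to masked_fill_(a.int(), 2), namely a.int(), has dtype int32, but the PyTorch version used here requires a Byte (uint8) mask. Convert the mask with byte() instead of int(). (From PyTorch 1.2 on, masks are expected to be torch.bool, and uint8 masks are deprecated.)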
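A minimal sketch of the same fix written against newer PyTorch releases (assumed 1.2 or later), where the mask should be a torch.bool tensor rather than uint8. The tensors are rebuilt from the printed values above, and the mask a.int() != 0 is chosen to select the same positions as a.byte() does for this data (entries that truncate to a nonzero integer).

```python
import torch

# Tensors rebuilt from the printed values in the session above
b = torch.tensor([[-0.2353, 1.5432, -1.4633, -0.1470, 0.5092],
                  [ 0.9389, 0.0459, 1.0843, -0.7417, 2.2712]])
a = torch.tensor([[-0.6341, 1.6152, -0.7700, -0.0680, 0.6076],
                  [-0.0076, 0.6456, 1.0000, -0.0277, -0.3403]])

# Bool mask: True wherever a truncates to a nonzero integer
# (same positions as a.byte() selects for these values)
mask = a.int() != 0

# Out-of-place masked_fill on an int32 copy of b; b itself is untouched
result = b.int().masked_fill(mask, 2)

print(mask)
print(result)
```

A comparison such as a.int() != 0 always yields torch.bool, so the mask dtype no longer depends on how a was cast.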
- Solution
>>> a.byte()
tensor([[0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
>>> b.int().masked_fill_(a.byte(),2)
tensor([[ 0,  2, -1,  0,  0],
        [ 0,  0,  2,  0,  2]], dtype=torch.int32)
>>> b.int()
tensor([[ 0,  1, -1,  0,  0],
        [ 0,  0,  1,  0,  2]], dtype=torch.int32)
>>> a.byte()
tensor([[0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
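That's all it takes. Note that b.int() returns a new int32 tensor, so masked_fill_ fills that copy in place while b itself is unchanged, as the second b.int() call shows.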
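A short sketch of the copy-versus-in-place behavior noted above, assuming a PyTorch release that accepts torch.bool masks. The mask here is written out by hand to match the positions selected by a.byte() in the session; the name b_int is only illustrative.

```python
import torch

b = torch.tensor([[-0.2353, 1.5432, -1.4633, -0.1470, 0.5092],
                  [ 0.9389, 0.0459, 1.0843, -0.7417, 2.2712]])
# Hand-written mask matching the nonzero entries of a.byte() above
mask = torch.tensor([[False, True, False, False, False],
                     [False, False, True, False, False]])

b_int = b.int()              # new int32 tensor; b is not modified
b_int.masked_fill_(mask, 2)  # in-place fill affects only the copy

print(b.dtype)  # b is still float32 with its original values
print(b_int)

# To fill b itself, call masked_fill_ on b directly (it keeps its float dtype)
b.masked_fill_(mask, 2.0)
print(b)
```

If the filled result is what you need, keep a reference to it (as b_int here) rather than chaining b.int().masked_fill_(...) and discarding the returned tensor.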