pytorch用masked_fill_函数的时候报错:RuntimeError: masked_fill_ only supports boolean masks, but got mask with dtype unsigned char
masked_fill_(mask, value)
Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
该函数的用法如下,其中a和mask的维数要一致,但是mask的不同维度上的取值要么和a一致,要么取1,该例子只展示mask的第2维取值1的情况(维度从0开始计数)
import torch
a=torch.tensor([[[5,5,5,5], [6,6,6,6], [7,7,7,7]], [[1,1,1,1],[2,2,2,2],[3,3,3,3]]])
print(a)
print(a.size())
print("#############################################3")
mask = torch.ByteTensor([[[1],[1],[0]],[[0],[1],[1]]])
print(mask.size())
b = a.masked_fill(mask, value=torch.tensor(-1e9))
print(b)
print(b.size())
网上大多例子都是这样的,但是该例运行失败,(我运行失败了,不知道有没有可能成功)报如上的错误,即:RuntimeError: masked_fill_ only supports boolean masks, but got mask with dtype unsigned char
解决方式
替换b = a.masked_fill(mask, value=torch.tensor(-1e9))为b = a.masked_fill(mask.bool(), value=torch.tensor(-1e9))