mask_fill的整体意思是使用 value 填充mask值为True的位置,所以需要注意mask的值,有时候需要~取反
masked_fill_(mask, value) - 函数名后面加下划线。in-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值,可以称为原地操作符。
masked_fill(mask, value) -> Tensor - 函数名后面没有下划线。out-of-place version 在 PyTorch 中是指当改变一个 tensor 的值的时候,经过复制操作,不是直接在原来的内存上改变它的值,而是修改复制的 tensor。
masked_select 函数最关键的参数就是布尔掩码 mask,传入 mask 参数的布尔张量通过 True 和 False (或 1 和 0) 来决定输入张量对应位置的元素是否保留,既然是一一对应的关系,这就需要传入 mask 中的布尔张量和传入 input 中的输入张量形状要相同。这里需要注意此时的形状相同包括显式的相等,还包括隐式的相等。
masked_select 函数虽然简单,但是有几点需要注意:
1、使用 masked_select 函数返回的结果都是 1D 张量,张量中的元素就是被筛选出来的元素值;
2、传入 input 参数中的输入张量和传入 mask 参数中的布尔张量形状可以不一致,但是布尔张量必须要能够通过广播机制扩展成和输入张量相同的形状;
参考:
https://zhuanlan.zhihu.com/p/348035584
https://blog.csdn.net/chengyq116/article/details/106961087