pytorch 的 where 有两种用法:
- torch.where
(condition, x, y)
- torch.Tensor.where
(condition, y)
第二种等同于 torch.where(condition, self, y),而不是 torch.where(condition, y, self)。即 condition 满足时不变,不满足时才取 y。
之前实际用的时候理解反了,以为是 按 condition 到 y 取值替换 self,是错的。而想达到这个效果,要对 condition 取反。
Code
import torch
a = torch.arange(6) # 原数据,即 self
print("a:", a) # [0, 1, 2, 3, 4, 5]
b = torch.arange(10, 16) # 替换数据,即 y
print("b:", b) # [10, 11, 12, 13, 14, 15]
c = a % 2 == 0 # condition
print("c:", c) # [ True, False, True, False, True, False]
print("TODO:按 c 从 b 取值替换 a")
d = a.where(c, b) # 掉坑用法
print("d:", d) # [ 0, 11, 2, 13, 4, 15]
e = a.where(c.logical_not(), b) # 正解
print("e:", e) # [10, 1, 12, 3, 14, 5]
- 输出
a: tensor([0, 1, 2, 3, 4, 5])
b: tensor([10, 11, 12, 13, 14, 15])
c: tensor([ True, False, True, False, True, False])
TODO:按 c 从 b 取值替换 a
d: tensor([ 0, 11, 2, 13, 4, 15])
e: tensor([10, 1, 12, 3, 14, 5])