一、逻辑运算符
- ==:判断两个元素是否相等。
- !=:判断两个元素是否不相等。
- <:判断左侧元素是否小于右侧元素。
- <=:判断左侧元素是否小于等于右侧元素。
- >:判断左侧元素是否大于右侧元素。
- >=:判断左侧元素是否大于等于右侧元素。
a = [1, 2, 3]
b = [2, 2, 3]
c = [x == y for x, y in zip(a, b)] # 逐元素相等比较
# 输出:[False, True, True]
d = [x < y for x, y in zip(a, b)] # 逐元素小于比较
# 输出:[True, False, False]
二、NumPy数组操作
- np.equal():逐元素相等比较。
- np.not_equal():逐元素不相等比较。
- np.less():逐元素小于比较。
- np.less_equal():逐元素小于等于比较。
- np.greater():逐元素大于比较。
- np.greater_equal():逐元素大于等于比较。
import numpy as np
a = np.array([1, 2, 3])
b = np.array([2, 2, 3])
c = np.equal(a, b) # 逐元素相等比较
# 输出:array([False, True, True])
d = np.less(a, b) # 逐元素小于比较
# 输出:array([True, False, False])
3. PyTorch的逐元素比较
- torch.eq():逐元素相等比较。
- torch.ne():逐元素不相等比较。
- torch.lt():逐元素小于比较。
- torch.le():逐元素小于等于比较。
- torch.gt():逐元素大于比较。
- torch.ge():逐元素大于等于比较。
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 2, 3])
c = torch.eq(a, b) # 逐元素相等比较
print(c) # 输出:tensor([False, True, True])
d = torch.lt(a, b) # 逐元素小于比较
print(d) # 输出:tensor([True, False, False])
举例子
masked_fill 是 PyTorch 中的一个函数,用于根据给定的遮罩(mask)对张量进行填充操作。函数的作用是将输入张量中,所有在 mask 中对应位置为 True 的元素替换为 value。
masked_fill(mask, value)
- mask:用于指示填充位置的布尔类型遮罩张量。
- value:填充的值,可以是标量或与输入张量相同形状的张量。
import torch
x = torch.tensor([1, 2, 3, 4, 5])
mask = torch.tensor([True, False, True, False, False])
filled = x.masked_fill(mask, 0)
print(filled) # 输出:tensor([0, 2, 0, 4, 5])