pytorch torch.where函数

torch.where 是 PyTorch 中用于条件选择的函数。它可以根据一个布尔条件在两个张量中选择元素,从而生成一个新的张量。

函数定义

torch.where(condition, input, other)
参数说明:
  1. condition
    • 一个布尔张量,表示条件判断结果。
    • 形状可以与 inputother 相同,或者可以广播到相同的形状。
  2. input
    • 满足条件时的值来源张量。
  3. other
    • 不满足条件时的值来源张量。
返回值:
  • 返回一个与 conditioninputother 形状兼容的张量。
  • 如果 condition 的某个位置为 True,返回 input 中对应位置的值;否则返回 other 中对应位置的值。

示例

1. 基本用法
import torch

x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([10, 20, 30, 40, 50])
condition = x > 3

print(condition) # tensor([False, False, False,  True,  True])
result = torch.where(condition, x, y)
print(result)  # 输出: tensor([10, 20, 30,  4,  5])

2. 多维张量 

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
condition = a < 3

result = torch.where(condition, a, b)
print(result)
# 输出:
# tensor([[ 1,  2],
#         [30, 40]])

3. 标量支持

inputother 可以是标量,而不是张量。

x = torch.tensor([1, 2, 3, 4, 5])
condition = x > 3

result = torch.where(condition, x, 0)
print(result)  # 输出: tensor([0, 0, 0, 4, 5])

解释:

  • 如果满足条件 x > 3,取 x 的值;否则取标量 0
4. 广播机制

例1: 

如果 conditioninputother 的形状不同,PyTorch 会自动广播使其兼容。

x = torch.tensor([[1, 2], [3, 4]])
condition = torch.tensor([True, False])

result = torch.where(condition, x, 0)
print(result)
# 输出:
# tensor([[1, 2],
#         [0, 0]])

解释:

  • condition 只有两列,通过广播扩展为形状 [2, 2]

例2:(不同维度的广播机制)

1 .

import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5]])
print(x)
condition = torch.tensor([True, False, True])

result = torch.where(condition, x, 0)
print(result)


x:
tensor([[1, 2, 3],
        [3, 4, 5]])
result:
tensor([[1, 0, 3],
        [3, 0, 5]])

 2 .

import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5]])
print(x)
condition = torch.tensor([[True], [False]])
print(condition)
result = torch.where(condition, x, 0)
print(result)


x:
tensor([[1, 2, 3],
        [3, 4, 5]])

condition:
tensor([[ True],
        [False]])

result:
tensor([[1, 2, 3],
        [0, 0, 0]])


常见用途

1. 替换特定值

将张量中大于某值的元素替换为某个固定值:

x = torch.tensor([1, 5, 10, 15])
x_clipped = torch.where(x > 10, 10, x)
print(x_clipped)  # 输出: tensor([ 1,  5, 10, 10])
2. 创建条件张量

使用条件逻辑生成一个新张量:

x = torch.linspace(-1, 1, 5)
y = torch.where(x > 0, 1, -1)
print(x) # tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
print(y)  # 输出: tensor([-1, -1, -1,  1,  1])

注意事项

  1. 数据类型一致性
    inputother 必须具有相同的数据类型,否则会抛出错误。

    x = torch.tensor([1.0, 2.0])
    y = torch.tensor([1, 2])
    torch.where(x > 1, x, y)  # 会报错,因为 x 是浮点型,y 是整型
    

    2. 广播机制
    当使用广播时,确保张量可以广播到相同的形状。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值