PyTorch学习(10):torch.where

PyTorch学习(1):torch.meshgrid的使用-CSDN博客

PyTorch学习(2):torch.device-CSDN博客

PyTorch学习(9):torch.topk-CSDN博客


目录

1. 写在前面

2. torch.where 函数概述

3. 基本用法示例


1. 写在前面

        在PyTorch中,torch.where 函数是一种强大的工具,用于根据条件从两个张量中选择元素。这个函数在数据处理和神经网络的正则化技术中特别有用,比如实现dropout或者masking策略。torch.where 的基本用法类似于Python的内建函数 np.where,但它专门针对PyTorch张量进行了优化。

2. torch.where 函数概述

        在PyTorch中,torch.where 函数是一种强大的工具,用于根据条件从两个张量中选择元素。这个函数在数据处理和神经网络的正则化技术中特别有用,比如实现dropout或者masking策略。torch.where 的基本用法类似于Python的内建函数 np.where,但它专门针对PyTorch张量进行了优化。

        torch.where 函数接受三个参数:

    condition:一个布尔型张量,指定了从另外两个张量中选择元素的条件。

    x:第一个张量,当 condition 为 True 时,其对应位置的元素将被选中。

    y:第二个张量,当 condition 为 False 时,其对应位置的元素将被选中。

torch.where 返回一个新的张量,其中包含了根据 condition 从 x 和 y 中选择的元素。

3. 基本用法示例

        下面是一个简单的例子,展示了如何使用 torch.where 来从一个张量中选择大于某个阈值的元素,并用另一个张量的相应元素替换它们。

import torch

# 创建两个张量

x = torch.tensor([1, 4, 3, 7, 2])

y = torch.tensor([0.5, 0.6, 0.7, 0.8, 0.9])

# 定义一个条件张量

condition = x > 3

# 使用torch.where根据条件选择元素

result = torch.where(condition, y, x)

print("原始张量 x:", x)

print("替换张量 y:", y)

print("条件张量 (x > 3):", condition)

print("torch.where 结果:", result)

        运行结果如下。

原始张量 x: tensor([1., 4., 3., 7., 2.])

替换张量 y: tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])

条件张量 (x > 3): tensor([False,  True, False,  True, False])

torch.where 结果: tensor([1.0000, 0.6000, 3.0000, 0.8000, 2.0000])

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值