pytorch学习11:where 和 gather

where

where语法:torch.where(condition, x, y)

  • 要求:condition、x、y矩阵的阶数(shape)必须相同,返回值也是相同阶数的矩阵
  • condition:条件矩阵,当元素为 True 时,填入 x 中对应的元素,否则填入 y 中对应的元素
  • x:condition中元素为 True 时将从 x 中选取元素
  • y:condition中元素为 False 时将从 y 中选取元素

示例一:c = max(a, b)

import torch

a = torch.randn(3, 3)
b = torch.randn(3, 3)

print("a:\n{}\n".format(a))
print("b:\n{}\n".format(b))

c = torch.where(a>b, a, b)

print("c:\n{}\n".format(c))

请添加图片描述

示例二:二分类预测

import torch

"""
假设 p 为二分类预测后的概率
当 p>0.5 时标识为 1
否则 标识为 0
"""

p = torch.rand(3, 3)

print("p:\n{}\n".format(p))

c = torch.where(p>0.5, 1, 0)

print("c:\n{}\n".format(c))

请添加图片描述

gather

gather语法:torch.gather(input, dim, index),下面的变量说明是按自己理解写的

  • input: 替换词典矩阵
  • dim: 替换方向
  • index: 需要替换的矩阵,类型必须为 int64

一维

示例:

import torch

_index = torch.trunc(torch.rand(8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 0, _index)
print("_output:\n{}\n".format(_output))

请添加图片描述

  • 要求 input 的列数大于 index 中元素的最大值。
  • 可以理解为把 index 中元素看作 input 的索引,用 input 中相应的元素替换 index,然后返回替换后的矩阵。
  • 返回的矩阵与 index 阶数(shape)相同

二维

dim=1

示例一:input相同

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

# 将 input 扩展为 2x10
_input = _input.expand(2, 10)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 1, _index)
print("_output:\n{}\n".format(_output))

请添加图片描述

  • 要求 input 的行数与 index 相同,列数要求与一维一致
  • 可以理解为两个一维的gather
  • 对 index 的两行数据进行同中替换

示例二:input不同

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
     [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])

print("_index:\n{}\n".format(_index))

print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 1, _index)
print("_output:\n{}\n".format(_output))

请添加图片描述

  • 要求input行列与 上个例子 一致
  • 可以理解为对 index 的两行以不同的标准进行替换

dim=0

示例:

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

# 将 input 切换为列向量,且将列扩展到与 index 列相同
_input = _input.view(10, 1).expand(10, 8)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 0, _index)
print("_output:\n{}\n".format(_output))

请添加图片描述

  • dim=1 很类似,可以理解为对每列进行替换
  • input 的要求也与 dim=1 一致

三维

import torch

_index = torch.trunc(torch.rand(2, 2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

_input = _input.expand(2, 2, 10)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 2, _index)
print("_output:\n{}\n".format(_output))

请添加图片描述

  • 与一维和二维很类似

关于 gather 的一些理解

  • input 的要求为:

    维数与index一致,且除了 dim 指定的维度外,每个维度的长度都要与 index 相应维度相同。

    dim指定的维度长度要大于 index 元素的最大值。

  • 关于 input 每 行/列 是否相同:

    虽然可以自定义让 input 不同 行/列 之间有所区别,从而让 index 替换的内容产生变化。但这样做的意义不大,且在高维度的时候自定义 input 会很麻烦。个人感觉用的最多的还是让 input 对所有要替换的 index 做同一替换。

  • 关于 gather:

    gather的使用可以理解为 中文汉字和拼音的替换 ,一维表示对一个句子进行替换,二维表示对一篇文章进行替换,三维表示对一本书进行替换。如:[我, 爱, 你] → [wo, ai, ni]。

总结

where 和 gather 的功能我们都可以通过自己写代码来实现,但 where 和 gather 底层优于自己写的代码,因此速度快很多。因此能用 where 和 gather 就尽量用。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值