PyTorch学习笔记——(2)PyTorch中where()函数和gather()函数的介绍

1、torch.where()

  • torch.where(condition, x, y)---->tensor
    • 功能:就是个三目运算符,第一个参数为条件,这个条件为真则为x,否则为y

例子:

a = torch.ones(2,2)
b = torch.zeros_like(a)
cond = torch.randn(2,2) # 这个是构成条件的矩阵

cond = tensor([[-1.0151,  0.9591],
        [-0.8334, -1.2766]])

torch.where(cond>0.5, a, b) # 当cond里面的元素>0.5时返回a对应位置的元素,否则为b对应位置的元素
# 输出:
tensor([[0., 1.],
        [0., 0.]])

2、torch.gather()

  • torch.gather(input, dim, index)
    • 功能:这个作用就是从输入input中,选择dim维度的数据,再利用索引index来找到这些数据。

例子:

input = torch.randn(3,3) # 生成输入
input = 
tensor([[-0.1183, -0.7858, -0.6806],
        [-1.0750, -0.8349, -0.4402],
        [-0.5234, -0.8766,  0.4833]])

index = torch.arange(3).view(-1, 3) # 生成索引
index = tensor([[0, 1, 2]])

# 在0维(也就是行)索引查找input的值
torch.gather(input, dim=0, index=index) 
# 输出:
tensor([[-0.1183, -0.8349,  0.4833]])

# 在1维(也就是列)索引查找input的值
torch.gather(input, dim=0, index=index) 
# 输出:
tensor([[-0.1183, -0.7858, -0.6806]])

若有用,欢迎点赞,若有错,请指正,谢谢!

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值