Pytorch 一些高阶操作

1. where

使用C=torch.where(condition,A,B)其中A,B,C,condition是shape相同的Tensor,C中的某些元素来自A,某些元素来自B,这由condition中对应位置的元素是1还是0来决定。如果condition对应位置元素是1,则C中的该位置的元素来自A中的该位置的元素,如果condition对应位置元素是0,则C中的该位置的元素来自B中的该位置的元素。

示例代码:

import torch

cond = torch.tensor([[0.6, 0.1], [0.2, 0.7]])
print(cond)
print(cond > 0.5)
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[4, 5], [6, 7]])
c = torch.where(cond > 0.5, a, b)
print(c)

输出结果:

tensor([[0.6000, 0.1000],
        [0.2000, 0.7000]])
tensor([[1, 0],
        [0, 1]], dtype=torch.uint8)
tensor([[1, 5],
        [6, 4]])

2. gather

使用torch.gather(input,dim,index, out=None)对元素实现一个查表映射的操作:

# 4张图像的10种分类的概率值
prob = torch.randn(4, 10)
print(prob)
# 取概率最大的前3个的概率值及其索引
_, idx = prob.topk(3, dim=1)
print(idx)
label = torch.arange(10) + 100
# 用于将idx的0~9映射到100~109
print(label)
out = torch.gather(label.expand(4, 10), dim=1, index=idx.long())
print(label.expand(4, 10))
print(out)

输出结果:

tensor([[ 1.5271e+00, -1.0496e+00,  1.8080e+00,  7.7568e-01,  6.1931e-02,
         -4.2938e-01,  5.5196e-01,  3.8433e-02,  6.3620e-01,  1.0572e+00],
        [-6.8786e-01, -2.4641e-02,  3.4654e-02,  1.9941e-01, -4.9521e-01,
          6.6272e-01,  1.1453e-02, -1.3736e+00, -8.1644e-01, -1.2865e+00],
        [ 1.7232e-03, -6.3537e-01, -1.7067e+00, -7.8266e-01, -5.5783e-01,
          3.8720e+00, -3.1364e-01,  5.0548e-01, -5.2201e-01, -3.3244e-01],
        [ 2.8563e-01, -4.2893e-01,  5.2956e-01, -9.6540e-01, -3.5586e-02,
         -4.6095e-01, -4.5072e-01,  8.8575e-01, -1.0540e-01, -2.2548e-01]])
idx:  tensor([[2, 0, 9],
        [5, 3, 2],
        [5, 7, 0],
        [7, 2, 0]])
label[10]:  tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
label[4, 10]:  tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
out:  tensor([[102, 100, 109],
        [105, 103, 102],
        [105, 107, 100],
        [107, 102, 100]])

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

洪流之源

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值