1. where
使用C=torch.where(condition,A,B)其中A,B,C,condition是shape相同的Tensor,C中的某些元素来自A,某些元素来自B,这由condition中对应位置的元素是1还是0来决定。如果condition对应位置元素是1,则C中的该位置的元素来自A中的该位置的元素,如果condition对应位置元素是0,则C中的该位置的元素来自B中的该位置的元素。
示例代码:
import torch
cond = torch.tensor([[0.6, 0.1], [0.2, 0.7]])
print(cond)
print(cond > 0.5)
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[4, 5], [6, 7]])
c = torch.where(cond > 0.5, a, b)
print(c)
输出结果:
tensor([[0.6000, 0.1000],
[0.2000, 0.7000]])
tensor([[1, 0],
[0, 1]], dtype=torch.uint8)
tensor([[1, 5],
[6, 4]])
2. gather
使用torch.gather(input,dim,index, out=None)对元素实现一个查表映射的操作:
# 4张图像的10种分类的概率值
prob = torch.randn(4, 10)
print(prob)
# 取概率最大的前3个的概率值及其索引
_, idx = prob.topk(3, dim=1)
print(idx)
label = torch.arange(10) + 100
# 用于将idx的0~9映射到100~109
print(label)
out = torch.gather(label.expand(4, 10), dim=1, index=idx.long())
print(label.expand(4, 10))
print(out)
输出结果:
tensor([[ 1.5271e+00, -1.0496e+00, 1.8080e+00, 7.7568e-01, 6.1931e-02,
-4.2938e-01, 5.5196e-01, 3.8433e-02, 6.3620e-01, 1.0572e+00],
[-6.8786e-01, -2.4641e-02, 3.4654e-02, 1.9941e-01, -4.9521e-01,
6.6272e-01, 1.1453e-02, -1.3736e+00, -8.1644e-01, -1.2865e+00],
[ 1.7232e-03, -6.3537e-01, -1.7067e+00, -7.8266e-01, -5.5783e-01,
3.8720e+00, -3.1364e-01, 5.0548e-01, -5.2201e-01, -3.3244e-01],
[ 2.8563e-01, -4.2893e-01, 5.2956e-01, -9.6540e-01, -3.5586e-02,
-4.6095e-01, -4.5072e-01, 8.8575e-01, -1.0540e-01, -2.2548e-01]])
idx: tensor([[2, 0, 9],
[5, 3, 2],
[5, 7, 0],
[7, 2, 0]])
label[10]: tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
label[4, 10]: tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
out: tensor([[102, 100, 109],
[105, 103, 102],
[105, 107, 100],
[107, 102, 100]])

328

被折叠的 条评论
为什么被折叠?



