两个高阶操作

where

torch.where(condition, x y) --->Tensor

对条件判断,满足的选择填充对应X位置元素,否则Y位置

a = torch.rand(2,2)
b = torch.tensor([[0,0],
     [0,0]])
c = torch.tensor([[1,1],[1,1]])
d = torch.where(a>0.5,b,c)    #对条件判断,满足的获取b位置元素
print(a)
print(d)
tensor([[0.7284, 0.9584],
        [0.0740, 0.3946]])    #a

tensor([[0, 0],
        [1, 1]])

gather

torch.gather( input, dim , index, out = None) ---->Tensor

收集操作: 输入一个tensor,指定维度,索引,返回一个和索引shape一样的tensor

mn = torch.randn(4,10)
idx = mn.topk(k=3,dim=1)
idx = idx[1]    #索引

print(mn),print(idx)

label = torch.arange(10)+100
p = torch.gather(label.expand(4,10),dim=1,index=idx.long())     #4行数从100-109
# print(idx.long())    #转换为longtensor数据格式                 #按照索引取数据,
print(p)
tensor([[ 0.9868,  1.3073, -0.3827, -1.0585,  0.0805, -1.3429,  0.6678, -0.3388,
         -0.4304,  0.2057],
        [ 0.1931, -0.6879, -0.1194,  1.4844, -0.4510, -0.4621, -0.9452,  1.1003,
         -1.0377, -1.0391],
        [ 0.6447, -0.1251, -0.7113,  0.8599, -0.3897, -0.3618, -0.4018,  0.8156,
          1.8524, -0.9277],
        [-0.5331,  0.1754,  0.2532,  0.3705,  0.6344,  1.5306, -1.2829, -0.2778,
         -0.4927, -0.3294]])

tensor([[1, 0, 6],    #前三大的索引
        [3, 7, 0],
        [8, 3, 7],
        [5, 4, 3]])

tensor([[101, 100, 106],
        [103, 107, 100],
        [108, 103, 107],
        [105, 104, 103]])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

onlywishes

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值