torch.gether()用法

torch.gether()用法

代码展示

b = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(b)
index_1 = torch.tensor([[0, 0, 0], [1, 1, 1]]).repeat(5, 1)
index_2 = torch.tensor([[0, 1, 1, 2, 2, 2, 1], [1, 1, 1, 0, 0, 2, 2]])
print(torch.gather(b, dim=1, index=index_2))
print(torch.gather(b, dim=0, index=index_1)) 

结果

tensor([[1, 2, 3],
        [4, 5, 6]])
tensor([[1, 2, 2, 3, 3, 3, 2],
        [5, 5, 5, 4, 4, 6, 6]])
tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])

分析用法:

torch.gether()有三个参数, 第一个为src即要被索引的张量; 第二个dim,意思在哪个维度上去索引;第三个index,索引张量。

如果想成功的用gather():想在哪个维度上索引,哪个维度就可以灵活变化,可以在1-n上去边,如上面例子

torch.gather(b, dim=1, index=index_2),dim=1,所以列数可以灵活变化,原张量的形状是(2, 3),索引形状可以是(2,n),其中n>=1,形状中对应元素的取值必须小于原张量列的最大值,如上述索引的值最大为2。

torch.gather(b, dim=0, index=index_1)的原理亦是如此,索引的最大值为1

假如一个张量的形状为src.shape(a, b, c):

当dim=0, index.shape(n,b,c), 其中n>=1即可

当dim=1, index.shape(a, n ,c),其中n>=1即可

当dim=2, index.shape(a,b,n) ,其中n>=1即可

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值