pytorch学习——torch.max() 和 torch.gather()

 1.torch.max()

b=torch.randn(3,4)
print("b:\n",b)
print('b.max(0):',b.max(0))
print('b.max(1):',b.max(1))
print('b.max(1)[0]:',b.max(1)[0])
print('b.max(1)[1]:',b.max(1)[1])
print('b.max(0)[0]:',b.max(0)[0])
print('b.max(0)[1]:',b.max(0)[1])
print('torch.max(b,0)[0]:',torch.max(b,0)[0])
print('torch.max(b,0)[1]:',torch.max(b,0)[1])
#torch.max(tensor,a)[b]中
#tensor代表想要进行操作的张量
#(a)表示想要进行操作的维度,如本例子为3*4的张量,维度为2。但是维度是从0开始索引的。
#例如对一个3*4*5的三维张量,a=0表示输出一个4*5的二维张量,进行比较的都是位于第一维的数据,详见下面的示例
#[b]表示想要输出的东西,0表示输出所得的最大值组成的张量,1表示输出所得的最大值对应的索引
b=torch.randn(3,4,5)
print("b:\n",b)
print('b.max(0)[0]:',b.max(0)[0])
print('b.max(0)[1]:',b.max(0)[1])
print('b.max(1)[0]:',b.max(1)[0])
print('b.max(1)[1]:',b.max(1)[1])
print('b.max(2)[0]:',b.max(2)[0])
print('b.max(2)[1]:',b.max(2)[1])

结果:

b:
 tensor([[ 0.8548,  0.8018, -1.8448, -1.9834],
        [-0.1640,  0.1191, -0.1988,  1.4577],
        [ 0.6882,  1.3879, -0.9425,  0.5111]])
b.max(0): torch.return_types.max(
values=tensor([ 0.8548,  1.3879, -0.1988,  1.4577]),
indices=tensor([0, 2, 1, 1]))
b.max(1): torch.return_types.max(
values=tensor([0.8548, 1.4577, 1.3879]),
indices=tensor([0, 3, 1]))
b.max(1)[0]: tensor([0.8548, 1.4577, 1.3879])
b.max(1)[1]: tensor([0, 3, 1])
b.max(0)[0]: tensor([ 0.8548,  1.3879, -0.1988,  1.4577])
b.max(0)[1]: tensor([0, 2, 1, 1])
torch.max(b,0)[0]: tensor([ 0.8548,  1.3879, -0.1988,  1.4577])
torch.max(b,0)[1]: tensor([0, 2, 1, 1])
b:
 tensor([[[-1.0886, -0.2318,  0.5687,  0.6992,  0.7993],
         [-1.6488, -0.4074,  0.7266, -1.5632, -0.2511],
         [ 0.2802, -0.1441, -0.0184, -1.2758,  1.0951],
         [ 0.1700, -1.6584, -1.8804,  1.1791, -0.3551]],

        [[-0.6114, -0.8189,  0.3983, -0.3744,  1.0545],
         [ 1.1177,  0.5962,  0.5073,  0.3863,  0.9423],
         [ 0.0200, -1.0322,  0.1329, -0.3569,  0.3295],
         [-1.0800, -0.5290, -1.7301,  0.0407,  0.4790]],

        [[ 0.6548, -0.6043,  1.7130, -0.9530,  0.7656],
         [-0.1233, -1.5174,  0.8852, -0.2953, -0.1956],
         [-0.6458, -1.4962, -1.4248,  0.6494,  0.8585],
         [-1.9865, -2.4318,  1.4741, -0.9116, -0.2791]]])
b.max(0)[0]: tensor([[ 0.6548, -0.2318,  1.7130,  0.6992,  1.0545],
        [ 1.1177,  0.5962,  0.8852,  0.3863,  0.9423],
        [ 0.2802, -0.1441,  0.1329,  0.6494,  1.0951],
        [ 0.1700, -0.5290,  1.4741,  1.1791,  0.4790]])
b.max(0)[1]: tensor([[2, 0, 2, 0, 1],
        [1, 1, 2, 1, 1],
        [0, 0, 1, 2, 0],
        [0, 1, 2, 0, 1]])
b.max(1)[0]: tensor([[ 0.2802, -0.1441,  0.7266,  1.1791,  1.0951],
        [ 1.1177,  0.5962,  0.5073,  0.3863,  1.0545],
        [ 0.6548, -0.6043,  1.7130,  0.6494,  0.8585]])
b.max(1)[1]: tensor([[2, 2, 1, 3, 2],
        [1, 1, 1, 1, 0],
        [0, 0, 0, 2, 2]])
b.max(2)[0]: tensor([[0.7993, 0.7266, 1.0951, 1.1791],
        [1.0545, 1.1177, 0.3295, 0.4790],
        [1.7130, 0.8852, 0.8585, 1.4741]])
b.max(2)[1]: tensor([[4, 2, 4, 3],
        [4, 0, 4, 4],
        [2, 2, 4, 2]])

 2.torch.gather()

在1中提到,[b]表示想要输出的东西,0表示输出所得的最大值组成的张量,1表示输出所得的最大值对应的索引张量。

而torch.gather(input, dim, index,out=None)中,input是想要抓取数据的原始张量,dim是要索引的维度,index就对应想要抓取的元素的索引(与上文提到的b=1时的输出是一个东西)。

给出示例:

b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))

结果:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 3.]])

 

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值