探究 torch.max() 中 keepdim 参数的影响


torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

  • input (Tensor) – the input Tensor
  • dim (int) – the dimension to reduce
  • keepdim (bool) – whether the output Tensors have dim retained or not
  • out (tuple, optional) – the result tuple of two output Tensors (max, max_indices)

以上是官方文档内对torch.max()参数的解释,详见官方文档

1. 二维

import torch
import numpy as np
testTensor = torch.randn(3,4)
print(testTensor)
print(testTensor.shape)
tensor([[-0.9576, -0.6422,  0.3878,  0.9311],
        [-2.2633, -1.1474,  0.8491,  0.3827],
        [-0.3658,  1.9686, -0.5864,  1.3481]])
torch.Size([3, 4])

1.1. keepdim=False

  • dim对应维度被完全消除
a = torch.max(testTensor,0,keepdim=False)
b = torch.max(testTensor,1,keepdim=False)
print(a)
print(a[0].shape)
print('------------------------------------')
print(b)
print(b[0].shape)
torch.return_types.max(
values=tensor([-0.3658,  1.9686,  0.8491,  1.3481]),
indices=tensor([2, 2, 1, 2]))
torch.Size([4])
------------------------------------
torch.return_types.max(
values=tensor([0.9311, 0.8491, 1.9686]),
indices=tensor([3, 2, 1]))
torch.Size([3])

1.2. keepdim=True

  • dim对应维度被变成1
a = torch.max(testTensor,0,keepdim=True)
b = torch.max(testTensor,1,keepdim=True)
print(a)
print(a[0].shape)
print('------------------------------------')
print(b)
print(b[0].shape)
torch.return_types.max(
values=tensor([[-0.3658,  1.9686,  0.8491,  1.3481]]),
indices=tensor([[2, 2, 1, 2]]))
torch.Size([1, 4])
------------------------------------
torch.return_types.max(
values=tensor([[0.9311],
        [0.8491],
        [1.9686]]),
indices=tensor([[3],
        [2],
        [1]]))
torch.Size([3, 1])

2. 三维

testTensor2 = torch.randn(3,4,5)
print(testTensor2)
print(testTensor2.shape)
tensor([[[ 1.2689e+00,  2.4801e-01, -6.1024e-01, -5.5274e-01,  1.3316e+00],
         [ 2.0961e-01,  5.7410e-01,  7.9837e-01,  1.9241e-01, -1.9209e-01],
         [ 8.3712e-01,  3.4982e-01,  1.0416e+00,  4.7590e-01,  6.5989e-01],
         [-5.6227e-01,  7.9599e-01,  1.2658e+00,  2.0524e+00,  3.1579e-01]],

        [[-5.2458e-01,  1.3057e+00,  5.0561e-01,  3.9769e-01,  1.3417e+00],
         [ 2.0886e-01,  5.1901e-01, -6.7622e-01, -6.7071e-01,  1.3424e+00],
         [ 9.2434e-02,  3.2769e-01,  1.1805e+00, -8.0025e-01, -3.0728e-01],
         [-3.6745e-01,  9.9187e-01,  9.2449e-01, -3.4050e-01,  1.1566e+00]],

        [[ 7.0320e-02,  2.8124e-01,  9.1458e-01, -2.5570e-01, -1.6569e+00],
         [ 5.7314e-03,  1.1809e+00, -4.1275e-01, -4.0657e-02,  1.3055e-03],
         [ 3.0750e-01, -8.1870e-02, -7.1268e-01, -5.0761e-01,  2.0539e+00],
         [-4.4742e-01,  1.0000e-01, -1.0436e-01,  2.6575e-01,  6.9229e-01]]])
torch.Size([3, 4, 5])

2.1. keepdim=False

  • dim对应维度被完全消除
a = torch.max(testTensor2,0,keepdim=False)
b = torch.max(testTensor2,1,keepdim=False)
c = torch.max(testTensor2,2,keepdim=False)
print(a)
print(a[0].shape)
print('------------------------------------')
print(b)
print(b[0].shape)
print('------------------------------------')
print(c)
print(c[0].shape)
torch.return_types.max(
values=tensor([[ 1.2689,  1.3057,  0.9146,  0.3977,  1.3417],
        [ 0.2096,  1.1809,  0.7984,  0.1924,  1.3424],
        [ 0.8371,  0.3498,  1.1805,  0.4759,  2.0539],
        [-0.3675,  0.9919,  1.2658,  2.0524,  1.1566]]),
indices=tensor([[0, 1, 2, 1, 1],
        [0, 2, 0, 0, 1],
        [0, 0, 1, 0, 2],
        [1, 1, 0, 0, 1]]))
torch.Size([4, 5])
------------------------------------
torch.return_types.max(
values=tensor([[1.2689, 0.7960, 1.2658, 2.0524, 1.3316],
        [0.2089, 1.3057, 1.1805, 0.3977, 1.3424],
        [0.3075, 1.1809, 0.9146, 0.2658, 2.0539]]),
indices=tensor([[0, 3, 3, 3, 0],
        [1, 0, 2, 0, 1],
        [2, 1, 0, 3, 2]]))
torch.Size([3, 5])
------------------------------------
torch.return_types.max(
values=tensor([[1.3316, 0.7984, 1.0416, 2.0524],
        [1.3417, 1.3424, 1.1805, 1.1566],
        [0.9146, 1.1809, 2.0539, 0.6923]]),
indices=tensor([[4, 2, 2, 3],
        [4, 4, 2, 4],
        [2, 1, 4, 4]]))
torch.Size([3, 4])

2.2. keepdim=True

  • dim对应维度被变成1
a = torch.max(testTensor2,0,keepdim=True)
b = torch.max(testTensor2,1,keepdim=True)
c = torch.max(testTensor2,2,keepdim=True)
print(a)
print(a[0].shape)
print('------------------------------------')
print(b)
print(b[0].shape)
print('------------------------------------')
print(c)
print(c[0].shape)
torch.return_types.max(
values=tensor([[[ 1.2689,  1.3057,  0.9146,  0.3977,  1.3417],
         [ 0.2096,  1.1809,  0.7984,  0.1924,  1.3424],
         [ 0.8371,  0.3498,  1.1805,  0.4759,  2.0539],
         [-0.3675,  0.9919,  1.2658,  2.0524,  1.1566]]]),
indices=tensor([[[0, 1, 2, 1, 1],
         [0, 2, 0, 0, 1],
         [0, 0, 1, 0, 2],
         [1, 1, 0, 0, 1]]]))
torch.Size([1, 4, 5])
------------------------------------
torch.return_types.max(
values=tensor([[[1.2689, 0.7960, 1.2658, 2.0524, 1.3316]],

        [[0.2089, 1.3057, 1.1805, 0.3977, 1.3424]],

        [[0.3075, 1.1809, 0.9146, 0.2658, 2.0539]]]),
indices=tensor([[[0, 3, 3, 3, 0]],

        [[1, 0, 2, 0, 1]],

        [[2, 1, 0, 3, 2]]]))
torch.Size([3, 1, 5])
------------------------------------
torch.return_types.max(
values=tensor([[[1.3316],
         [0.7984],
         [1.0416],
         [2.0524]],

        [[1.3417],
         [1.3424],
         [1.1805],
         [1.1566]],

        [[0.9146],
         [1.1809],
         [2.0539],
         [0.6923]]]),
indices=tensor([[[4],
         [2],
         [2],
         [3]],

        [[4],
         [4],
         [2],
         [4]],

        [[2],
         [1],
         [4],
         [4]]]))
torch.Size([3, 4, 1])
  • 8
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值