torch中关于torch.max()和torch.min()函数的理解

  • 简介

在tensor类型的数据中,max和min函数常用来比较两个tensor数据的大小,或者取出tensor数据中的最大值。关于max函数和min函数的用法有以下几种场景:

对于tensorA和tensorB:

  1. torch.max(tensorA):返回tensor中的最大值。
  2. torch.mac(tensorA,dim):dim表示指定的维度,返回指定维度的最大数和对应下标
  3. torch.max(tensorA,tensorB):比较tensorA和tensorB相对较大的元素。
输入:
x = th.arange(0,16,1).view(4,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())
输出:
x:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([ 3,  7, 11, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,0):
 torch.return_types.max(
values=tensor([12, 13, 14, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,1)[0]:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[1]:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data.numpy():
 [3 3 3 3]
t.max(x,1)[1].data.numpy().squeeze():
 [3 3 3 3]
t.max(x,1)[0].data:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[0].data.numpy():
 [ 3  7 11 15]
t.max(x,1)[0].data.numpy().squeeze():
 [ 3  7 11 15]

上面代码中x是一个4*4的二阶张量:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

t.max(x)表示从x中取出最大的那个值:15

t.max(x,1)表示每行中取出最大的元素,并返回其下标,第一行为3,第二行为7,第三行为11,第四行为15.注意:t.max(x,1)中的1表示行,0表示列,因为此处的例子是二阶的张量,只有行和列。

所以t.max(x,1)表示每列中取出最大的元素,即[3,7,11,15]。

.data 只返回variable中的数据部分(去掉Variable containing:)。

.data.numpy() 把数据转化成numpy ndarry。

].data.numpy().squeeze() 把数据条目中维度为1 的删除掉。

torch.max(tensorA,tensorB) element-wise 比较tensorA 和tensorB 中的元素,返回较大的那个值。

那如果是3阶的张两呢?

下面我们将x的张量变成2*2*4,即2维的两行四列。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,0):
 torch.return_types.max(
values=tensor([[ 8,  9, 10, 11],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,1)[0]:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[1]:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data.numpy():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[1].data.numpy().squeeze():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[0].data:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[0].data.numpy():
 [[ 4  5  6  7]
 [12 13 14 15]]
t.max(x,1)[0].data.numpy().squeeze():
 [[ 4  5  6  7]
 [12 13 14 15]]

从输出结果看,

  1. t.max(x):依然是从x中取出最大的一个值。
  2. t.max(x,1):输出
    [[ 4,  5,  6,  7],
    [12, 13, 14, 15]],因为此时有两个维度,每个维度上有两行四列的数据,所以1表示行,即取出每个维度上每个行的最大值。[ 4, 5, 6, 7]是第一个维度中的最大值所在的行,[12, 13, 14, 15]是第二个维度中最大值所在的行。
  3. t.max(x,0):输出:
    torch.return_types.max(
    values=tensor([[ 8,  9, 10, 11],
            [12, 13, 14, 15]]),
    indices=tensor([[1, 1, 1, 1],
            [1, 1, 1, 1]]))此时,0表示是维度,x有两个维度,表示从这两个维度中取出最大值所在的维度。即[[ 8, 9, 10, 11], [12, 13, 14, 15]]。
    

假如我们想要取出每个维度中每一列中最大的元素该怎么办呢?我们只需要指定t.max(x,dim)中的dim=2就可以取出列向量了。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,2):\n',t.max(x,2))
print('t.max(x,2)[0]:\n',t.max(x,2))
print('t.max(x,1)[0][0]:\n',t.max(x,2)[0][0])


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,2):
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,2)[0]:
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,1)[0][0]:
 tensor([3, 7])

 

  • 9
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值