torch中关于torch.max()和torch.min()函数的理解

  • 简介

在tensor类型的数据中,max和min函数常用来比较两个tensor数据的大小,或者取出tensor数据中的最大值。关于max函数和min函数的用法有以下几种场景:

对于tensorA和tensorB:

  1. torch.max(tensorA):返回tensor中的最大值。
  2. torch.mac(tensorA,dim):dim表示指定的维度,返回指定维度的最大数和对应下标
  3. torch.max(tensorA,tensorB):比较tensorA和tensorB相对较大的元素。
输入:
x = th.arange(0,16,1).view(4,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())
输出:
x:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([ 3,  7, 11, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,0):
 torch.return_types.max(
values=tensor([12, 13, 14, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,1)[0]:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[1]:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data.numpy():
 [3 3 3 3]
t.max(x,1)[1].data.numpy().squeeze():
 [3 3 3 3]
t.max(x,1)[0].data:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[0].data.numpy():
 [ 3  7 11 15]
t.max(x,1)[0].data.numpy().squeeze():
 [ 3  7 11 15]

上面代码中x是一个4*4的二阶张量:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

t.max(x)表示从x中取出最大的那个值:15

t.max(x,1)表示每行中取出最大的元素,并返回其下标,第一行为3,第二行为7,第三行为11,第四行为15.注意:t.max(x,1)中的1表示行,0表示列,因为此处的例子是二阶的张量,只有行和列。

所以t.max(x,1)表示每列中取出最大的元素,即[3,7,11,15]。

.data 只返回variable中的数据部分(去掉Variable containing:)。

.data.numpy() 把数据转化成numpy ndarry。

].data.numpy().squeeze() 把数据条目中维度为1 的删除掉。

torch.max(tensorA,tensorB) element-wise 比较tensorA 和tensorB 中的元素,返回较大的那个值。

那如果是3阶的张两呢?

下面我们将x的张量变成2*2*4,即2维的两行四列。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,0):
 torch.return_types.max(
values=tensor([[ 8,  9, 10, 11],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,1)[0]:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[1]:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data.numpy():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[1].data.numpy().squeeze():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[0].data:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[0].data.numpy():
 [[ 4  5  6  7]
 [12 13 14 15]]
t.max(x,1)[0].data.numpy().squeeze():
 [[ 4  5  6  7]
 [12 13 14 15]]

从输出结果看,

  1. t.max(x):依然是从x中取出最大的一个值。
  2. t.max(x,1):输出
    [[ 4,  5,  6,  7],
    [12, 13, 14, 15]],因为此时有两个维度,每个维度上有两行四列的数据,所以1表示行,即取出每个维度上每个行的最大值。[ 4, 5, 6, 7]是第一个维度中的最大值所在的行,[12, 13, 14, 15]是第二个维度中最大值所在的行。
  3. t.max(x,0):输出:
    torch.return_types.max(
    values=tensor([[ 8,  9, 10, 11],
            [12, 13, 14, 15]]),
    indices=tensor([[1, 1, 1, 1],
            [1, 1, 1, 1]]))此时,0表示是维度,x有两个维度,表示从这两个维度中取出最大值所在的维度。即[[ 8, 9, 10, 11], [12, 13, 14, 15]]。
    

假如我们想要取出每个维度中每一列中最大的元素该怎么办呢?我们只需要指定t.max(x,dim)中的dim=2就可以取出列向量了。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,2):\n',t.max(x,2))
print('t.max(x,2)[0]:\n',t.max(x,2))
print('t.max(x,1)[0][0]:\n',t.max(x,2)[0][0])


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,2):
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,2)[0]:
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,1)[0][0]:
 tensor([3, 7])

 

### PyTorch `torch.histc` 与 `torch.histogram` 的功能及用法 #### 功能对比 `torch.histc` `torch.histogram` 都用于计算输入张量的直方图,但它们的设计目标参数有所不同。 - **`torch.histc`**: 这是一个较早引入的方法,主要用于返回固定范围内的直方图分布。它的核心特点是通过指定最小值 (`min`) 最大值 (`max`) 来定义区间边界,并自动划分成均匀宽度的子区间[^1]。如果某些数据点超出了 `[min, max]` 范围,则这些数据会被忽略不计[^4]。 - **`torch.histogram`**: 它是更灵活的一种方法,在 PyTorch 中作为 NumPy 接口的一部分提供支持。除了可以接受固定的上下限外,还可以传入自定义的 bin 边界数组来创建非均分间隔的直方图[^6]。这意味着用户能够更加精细地控制每个区间的大小以及位置安排。 #### 使用场景分析 当需要快速构建具有预设数量等宽箱子的标准统计图表时可以选择使用 `torch.histc`;而面对复杂需求比如不同长度或者形状各异的数据分区情况则更适合采用 `torch.histogram`. 以下是两个函数的具体差异总结: | 特性 | torch.histc | torch.histogram | |---------------------|--------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------| | 输入类型 | 单一数值型张量 | 支持多种类型的张量 | | Bin 设置 | 基于整数 bins 数目设定 | 可以给定具体的边界的列表 | | 处理越界值 | 自动丢弃超出[min,max]之外的所有样本 | 不会抛弃任何数据项 | | 输出形式 | 返回单一维度的结果向量 | 同样给出频率计数的同时也会附带对应的bin边缘信息 | #### 示例代码展示 下面分别给出了两者的简单应用实例以便直观感受其操作过程: 对于 `torch.histc`, 下面的例子展示了如何利用该命令生成一个简单的频次表: ```python import torch data = torch.tensor([1., 1, 2, 2, 2, 5, 8]) area_intersect = torch.histc(data, bins=5, min=0, max=9) print(area_intersect) # tensor([2., 3., 1., 0., 1.]) ``` 而对于更为通用的 `torch.histogram`, 则可以通过如下方式实现相同目的同时也获取到更多细节: ```python values = torch.tensor([1., 1, 2, 2, 2, 5, 8]) edges = torch.arange(0, 10).float() histogram_values, edges_out = torch.histogram(values, edges) print(histogram_values) # tensor([2., 3., 0., 0., 0., 1., 0., 0., 1.]) print(edges_out) # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) ``` 上述例子清晰表明了两者之间的主要差别在于后者提供了额外关于箱线的信息即edge values. ### 结论 综上所述,尽管二者都旨在解决相似的问题——绘制直方图并量化离散化后的概率密度估计等问题,但由于各自侧重点有所偏移所以适用场合也存在显著区别。因此在实际项目开发过程中应根据具体业务逻辑合理选用合适的工具从而达到最佳效果。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值