Pytorch学习-torch.max和min深度解析
max的使用 min同理
对于tensorA和tensorB:
1)torch.max(tensorA) 返回tensor中的最大值
2)torch.max(tensorA,dim) 返回指定维度的最大数和对应下标
3)torch.max(tensorA,tensorB) 比较tensorA和tensorB相对较大的元素
dim参数理解
搞清楚dim参数
第0维是行,第1维是列!!!
结论:
1)dim=0 查找每列的最大值,返回行下标索引
2)dim=1 查找每行的最大值,返回列下标索引
3)不添加dim参数,返回所有值中的最大值,且无索引
二维张量使用max()
t=torch.randn(2,3)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))
结果:
tensor([[ 0.0231, 0.2109, -1.6104],
[-0.5777, -1.3870, -0.9925]])
-------max dim=0 -------
torch.return_types.max(
values=tensor([ 0.0231, 0.2109, -0.9925]),
indices=tensor([0, 0, 1]))
-------max dim=1 -------
torch.return_types.max(
values=tensor([ 0.2109, -0.5777]),
indices=tensor([1, 0]))
**???疑问:**为什么0维是行,但是max时返回是列中的最大值呢?
理解:!!在其他维度均确定的情况下,比较所有dim维对应的数据,找到其中的最大值,并返回索引。
比如:
dim=0时 除了[0]维 还有[1]两个维度
第一列 遍历两行得到 [0][0] 和 [1][0] max为0.0231
第二列 遍历两行得到 [0][1] 和 [0][2] max为0.2109
第三列 遍历两行得到 [1][1] 和 [1][2] max为-0.9925
三维张量使用max()
第0维顺着层,第1维顺着行,第2维度顺着列
t = torch.randn(2,2,2)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))
print("-------max dim=2 -------")
print(torch.max(t,dim=2))
结果:
tensor([[[-1.6519, -0.3087],
[-0.6982, 0.4515]],
[[-0.4648, 0.8958],
[-1.4150, -1.4633]]])
-------max dim=0 ------- [[-0.4648, 0.8958],[-0.6982, 0.4515]] [[1,1],[0,0]] 列确定 比较行
torch.return_types.max(
values=tensor([[-0.4648, 0.8958],
[-0.6982, 0.4515]]),
indices=tensor([[1, 1],
[0, 0]]))
-------max dim=1 ------- [[-0.6982, 0.4515],[-0.4648,0.8958]] [[],[]] ([0][0][0],[0][1][0]),([0][0][1],[0][1][1]),([1][0][0],[1][1][0]),([1][0][1],[1][1][1])
torch.return_types.max(
values=tensor([[-0.6982, 0.4515],
[-0.4648, 0.8958]]),
indices=tensor([[1, 1],
[0, 0]]))
-------max dim=2 ------- [0][0]_,[0][1]_,[1][0]_,[1][1]_ ([0][0][0],[0][0][1]) ([0][1][0],[0][1][1]) ([1][0][0],[1][0][1]) ([1][1][0],[1][1][1])
torch.return_types.max(
values=tensor([[-0.3087, 0.4515],
[ 0.8958, -1.4150]]),
indices=tensor([[1, 1],
[1, 0]]))