torch.max用法:
#coding=utf-8
import torch
from torch.autograd import Variable
#返回输入tensor中所有元素的最大值
a = torch.randn(1, 3)
print a
print torch.max(a)
print '--' * 10
#按维度dim 返回最大值 torch.max)(a,0) 返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)
b = torch.randn(3,3)
b = Variable(b)
print b
print torch.max(b,0)
print '--' * 10
#torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
print torch.max(b,1)
print '----' * 10
print torch.max(b,1)[0] #只返回最大值
print '----' * 10
print torch.max(b,1)[1] #只返回最大值所在的索引
结果:
0.8280 -0.3046 0.4422
[torch.FloatTensor of size 1x3]
0.828029453754
--------------------
Variable containing:
-0.2242 0.2954 1.3393
0.2511 0.7484 1.4388
-0.1107 0.0453 2.2737
[torch.FloatTensor of size 3x3]
(Variable containing:
0.2511 0.7484 2.2737
[torch.FloatTensor of size 1x3]
, Variable containing:
1 1 2
[torch.LongTensor of size 1x3]
)
--------------------
(Variable containing:
1.3393
1.4388
2.2737
[torch.FloatTensor of size 3x1]
, Variable containing:
2
2
2
[torch.LongTensor of size 3x1]
)
----------------------------------------
Variable containing:
1.3393
1.4388
2.2737
[torch.FloatTensor of size 3x1]
----------------------------------------
Variable containing:
2
2
2
[torch.LongTensor of size 3x1]