在Pytorch中,max函数包括torch
中顶级函数torch.max
和Tensor对象的max
函数,且均实现了overload
(函数重载),以泛化其功能。其常见的使用方法包括:
可用于Tensor对象内部元素的极值获取,或者两个Tensor对象的逐元素对比。
1. 张量内部元素的极值
import torch
a = torch.arange(0, 6).reshape(2,3) # tensor([[0, 1, 2][3, 4, 5]])
# 所有元素的最大值
torch.max(a) # 方法一,tensor(5)
a.max() # 方法二,tensor(5)
# 沿某个dim的最大值
torch.max(a, dim=1) # 方法一,(tensor([2, 5]), tensor([2, 2])), 第二个元素即为torch.argmax(a, dim=1)
a.max(dim=1) # 方法二
2. 逐元素对比两个张量
该操作支持广播
import torch
a = torch.arange(0, 6).reshape(2,3) # tensor([[0, 1, 2][3, 4, 5]])
torch.max(a, torch.tensor(2)) # 方法一:tensor([[2, 2, 2], [3, 4, 5]])
a.max(torch.tensor(2)) # 方法二
一个典型应用就是relu
激活函数:
def relu(t):
return torch.max(t, torch.zeros_like(t))