1. 函数语法格式
torch.nn.MaxPool2d(
kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False
)
2. 参数解释
kernel_size
(int or tuple)【必选】:max pooling 的窗口大小,当最大池化窗口是方形的时候,只需要一个整数边长即可;最大池化窗口不是方形时,要输入一个元组表 高和宽。stride
(int or tuple, optional)【可选】:max pooling 的窗口移动的步长。默认值是 kernel_sizepadding
(int or tuple, optional)【可选】:输入的每一条边补充0的层数dilation
(int or tuple, optional)【可选】:一个控制窗口中元素步幅的参数return_indices
(bool)【可选】:如果等于 True,会返回输出最大值的序号,对于上采样操作会有帮助ceil_mode
(bool)【可选】:如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作
⭐ dilation 说明
如果我们设置的 dilation=0 的话,效果如图:蓝色为输入,绿色为输出,最大池化窗口为3 × 3
如果设置的是dilation=1,那么效果如图:蓝色为输入,绿色为输出,最大池化窗口卷积核仍为 3 × 3 。
3. 尺寸关系
输入可以为:
(
N
,
C
i
n
,
H
i
n
,
W
i
n
)
(N,C_{in},H_{in},W_{in})
(N,Cin,Hin,Win) 或
(
C
i
n
,
H
i
n
,
W
i
n
)
(C_{in},H_{in},W_{in})
(Cin,Hin,Win)
输出可以为:
(
N
,
C
o
u
t
,
H
o
u
t
,
W
o
u
t
)
(N,C_{out},H_{out},W_{out})
(N,Cout,Hout,Wout) 或
(
C
o
u
t
,
H
o
u
t
,
W
o
u
t
)
(C_{out},H_{out},W_{out})
(Cout,Hout,Wout)
它们之间的关系为:
H
o
u
t
=
⌊
H
i
n
+
2
×
p
a
d
d
i
n
g
[
0
]
−
d
i
l
a
t
i
o
n
[
0
]
×
(
k
e
r
n
e
l
_
s
i
z
e
[
0
]
−
1
)
−
1
s
t
r
i
d
e
[
0
]
+
1
⌋
H_{out}=\left\lfloor\frac{H_{in}+2 \times padding[0]-dilation[0] \times(kernel\_size[0]-1)-1}{ stride [0]}+1\right\rfloor
Hout=⌊stride[0]Hin+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1+1⌋
W o u t = ⌊ W i n + 2 × p a d d i n g [ 1 ] − d i l a t i o n [ 1 ] × ( k e r n e l _ s i z e [ 1 ] − 1 ) − 1 s t r i d e [ 1 ] + 1 ⌋ W_{out}=\left\lfloor\frac{W_{in}+2 \times padding[1]-dilation[1] \times(kernel\_size[1]-1)-1}{ stride [1]}+1\right\rfloor Wout=⌊stride[1]Win+2×padding[1]−dilation[1]×(kernel_size[1]−1)−1+1⌋
4. 使用案例
# pool of square window of size=3, stride=2
m = nn.MaxPool2d(3, stride=2)
# pool of non-square window
m = nn.MaxPool2d((3, 2), stride=(2, 1))
input = torch.randn(20, 16, 50, 32)
output = m(input)
5. nn.functional.max_pool2d
⭐ 区别
torch.nn.MaxPool2d
和 torch.nn.functional.max_pool2d
,在 pytorch 构建模型中,都可以作为最大池化层的引入,但前者为类模块,后者为函数,在使用上存在不同。
⭐ 使用
torch.nn.functional.max_pool2d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False
)