本文以tf1.13版本中的tf.nn.pool函数为例,介绍N-D输入的池化操作,函数输入参数如下。
tf.nn.pool(
input, #N+2维的tensor,若date_format格式不以NC开头,则input形状[batch_size]+ N-D_input_shape + [num_channels],以NC开头时,形状为[batch_size, num_channels] + N-D_input_shape
window_shape, #N个大于等于1的序列
pooling_type, #两种类型:"AVG" or "MAX"
padding, #同conv的类型:"SAME" or "VALID"
dilation_rate=None, #意义同input stride or dilation,若其值大于1,则strides的所有值为1
strides=None, #N个大于等于1的序列,若其值大于1,则dilation_rate的所有值为1
name=None, #OP名称
data_format=None #N=1时为NWC(默认)或NCW,N=2时为NHWC(默认)或NCHW,N=3时为NDHWC(默认)或
NCDHW
)
当data_format格式以NC开头时,该函数的处理逻