#encoding:utf-8
import numpy as np
import tensorflow as tf
# Demo: 2x2 max pooling with stride 2 over a single 4x4 single-channel image,
# using the TensorFlow 1.x graph/session API.

# Feed a plain 4x4 matrix and reshape it to (1, 4, 4, 1): one image, 4x4
# spatial size, one channel.
x_image = tf.placeholder(tf.float32, shape=[4, 4])
x = tf.reshape(x_image, [1, 4, 4, 1])

# max pooling parameters
# x:       input of the pooling op (the reshaped tensor above)
# ksize:   size of the pooling window along each input dimension
# strides: step of the window along each input dimension, usually
#          [1, stride, stride, 1]
# padding: "VALID" or "SAME"
ksize = [1, 2, 2, 1]
strides = [1, 2, 2, 1]
padding = 'VALID'
y = tf.nn.max_pool(x, ksize, strides, padding)

x_data = np.array([
    [4, 3, 1, 8],
    [7, 2, 6, 3],
    [2, 0, 1, 1],
    [3, 4, 2, 5],
], dtype=np.float32)  # match the placeholder's float32 dtype explicitly

with tf.Session() as sess:
    # Fetch both tensors in a single run, storing the results under new names
    # so the graph tensors x and y are not shadowed by numpy arrays.
    x_val, y_val = sess.run([x, y], feed_dict={x_image: x_data})
    # Single-argument print(...) calls behave the same on Python 2 and 3.
    print("The shape of x: {}".format(x_val.shape))
    print(x_val.reshape(4, 4))
    print("")
    print("The shape of y: {}".format(y_val.shape))
    print(y_val.reshape(2, 2))
    print("")

# Output (array spacing may vary with the numpy version):
# The shape of x: (1, 4, 4, 1)
# [[ 4.  3.  1.  8.]
#  [ 7.  2.  6.  3.]
#  [ 2.  0.  1.  1.]
#  [ 3.  4.  2.  5.]]
#
# The shape of y: (1, 2, 2, 1)
# [[ 7.  8.]
#  [ 4.  5.]]