TensorFlow是目前深度学习最流行的框架,很有学习的必要,下面我们就来实际动手,使用TensorFlow搭建一个简单的CNN,来对经典的mnist数据集进行数字识别。
step 0 导入TensorFlow
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data
step 1 加载数据集mnist
声明两个placeholder,用于存储神经网络的输入,输入包括image和label。这里加载的image是(784,)的shape。
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
x = tf.placeholder(tf.float32,[None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
step 2 定义weights和bias
为了使代码整洁,这里把weight和bias的初始化封装成函数。
#----Weight Initialization---#
#One should generally initialize weights with a small amount of noise for symmetry breaking, and to prevent 0 gradients
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
step 3 定义卷积层和maxpooling
同样,为了代码的整洁,将卷积层和maxpooling封装起来。padding=‘SAME’表示使用padding,