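1. The MNIST dataset
We use the MNIST handwritten digit dataset. You can download the files yourself or load them directly through TensorFlow.
Dataset overview:
train-images-idx3-ubyte.gz: training images, 60,000 images of 28x28 pixels each
train-labels-idx1-ubyte.gz: training labels, 60,000 labels
t10k-images-idx3-ubyte.gz: test images, 10,000 images of 28x28 pixels each
t10k-labels-idx1-ubyte.gz: test labels, 10,000 labels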
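If you download the four .gz files yourself, each one (after gunzip) is in the IDX format: a 4-byte magic number (two zero bytes, a type code where 0x08 means unsigned byte, and the number of dimensions), then one big-endian 32-bit size per dimension, then the raw data. The following is a minimal stdlib-only sketch of a reader, assuming the unsigned-byte type that MNIST uses; `parse_idx` and `load_idx_gz` are hypothetical helper names, not part of TensorFlow.

```python
import gzip
import struct
from typing import List, Tuple


def parse_idx(data: bytes) -> Tuple[List[int], bytes]:
    """Parse an uncompressed IDX buffer holding unsigned-byte data.

    Returns (dims, payload), e.g. dims == [60000, 28, 28] for the training
    images, plus the raw pixel or label bytes that follow the header.
    """
    zero1, zero2, dtype, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: first two magic bytes must be 0")
    if dtype != 0x08:
        raise ValueError("this sketch only handles unsigned-byte (0x08) data")
    # One big-endian uint32 size per dimension follows the magic number.
    dims = list(struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim]))
    payload = data[4 + 4 * ndim:]
    expected = 1
    for d in dims:
        expected *= d
    if len(payload) != expected:
        raise ValueError(f"expected {expected} data bytes, got {len(payload)}")
    return dims, payload


def load_idx_gz(path: str) -> Tuple[List[int], bytes]:
    """Read one of the *.gz files listed above and parse it."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


# Usage on a tiny synthetic file: 2 "images" of 2x2 pixels.
fake = struct.pack(">BBBB", 0, 0, 0x08, 3) + struct.pack(">III", 2, 2, 2) + bytes(range(8))
dims, payload = parse_idx(fake)
print(dims)  # [2, 2, 2]
```

For the real data, `load_idx_gz("train-images-idx3-ubyte.gz")` would be expected to return dims of [60000, 28, 28], and the label files a single dimension of 60000 or 10000.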
2. Implementation with TensorFlow 1.x
2.1 Environment setup
Windows + CUDA 10.0 + cuDNN + Anaconda + Python 3.7 + TensorFlow 1.15 (GPU)
2.2 Loading the data
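import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Reads the four .gz files from ./MNIST_data (downloading them if missing); labels are one-hot encoded
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)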
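With `one_hot=True`, each label comes back as a length-10 vector that is 1.0 at the digit's index and 0.0 elsewhere, which is the form the `ys` labels take in the cross-entropy loss. (By default `read_data_sets` also sets aside 5,000 of the 60,000 training images as `mnist.validation` and returns each image flattened into 784 values.) A small stdlib sketch of the encoding and its inverse (argmax); `one_hot` and `from_one_hot` are illustrative names, not TensorFlow functions.

```python
from typing import List


def one_hot(label: int, num_classes: int = 10) -> List[float]:
    """Turn a class index into a one-hot vector, e.g. 3 -> [0, 0, 0, 1, 0, ...]."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def from_one_hot(vec: List[float]) -> int:
    """Recover the class index: the position of the largest entry (argmax)."""
    return max(range(len(vec)), key=lambda i: vec[i])


encoded = one_hot(3)
print(encoded)  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(from_one_hot(encoded))  # 3
```

The same argmax step is what turns the network's 10 softmax outputs back into a predicted digit.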
2.3 Helper functions for variables and layers
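# Return a weight variable initialized from a truncated normal distribution
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# Return a bias variable initialized to a small positive constant
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)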
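`tf.truncated_normal` samples from a normal distribution but discards and re-draws any value more than two standard deviations from the mean, so with `stddev=0.1` every initial weight lies in [-0.2, 0.2]. The 0.1 bias is a common choice for ReLU layers, meant to keep units from starting out inactive. A stdlib sketch of the rejection-sampling idea, using a hypothetical helper `truncated_normal` (this is not TensorFlow's implementation):

```python
import random
from typing import List


def truncated_normal(n: int, mean: float = 0.0, stddev: float = 0.1, seed: int = 0) -> List[float]:
    """Draw n samples from N(mean, stddev), re-drawing any sample that falls
    more than 2 standard deviations from the mean."""
    rng = random.Random(seed)
    out: List[float] = []
    while len(out) < n:
        x = rng.gauss(mean, stddev)
        if abs(x - mean) <= 2 * stddev:  # keep only values within +/- 2 stddev
            out.append(x)
    return out


# As many values as the 5x5x1x32 conv1 kernel holds
weights = truncated_normal(5 * 5 * 1 * 32)
print(len(weights), max(abs(w) for w in weights) <= 0.2)  # 800 True
```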
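# Return a 2-D convolution with stride 1 and 'SAME' (zero) padding
# strides: [1, stride_height, stride_width, 1] for NHWC input
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

# Return a 2x2 max-pooling layer with stride 2 (halves height and width)
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')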
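With `padding='SAME'`, TensorFlow pads the input so that the output spatial size is ceil(input / stride), independent of the kernel size. That is why the 5x5 convolutions in the CNN of 2.4 keep 28x28 and 14x14 unchanged while each 2x2 pooling halves them, ending at the 7x7x64 volume that feeds the fully connected layer. A short stdlib check of that arithmetic (the helper names are illustrative):

```python
import math


def same_out(size: int, stride: int) -> int:
    """Output length with padding='SAME': ceil(size / stride), whatever the kernel size."""
    return math.ceil(size / stride)


def valid_out(size: int, kernel: int, stride: int) -> int:
    """Output length with padding='VALID', for comparison: ceil((size - kernel + 1) / stride)."""
    return math.ceil((size - kernel + 1) / stride)


size = 28
trace = [size]
for stride in (1, 2, 1, 2):  # conv1, pool1, conv2, pool2
    size = same_out(size, stride)
    trace.append(size)

print(trace)                # [28, 28, 14, 14, 7]
print(size * size * 64)     # 3136, the 7*7*64 input width of FC1
print(valid_out(28, 5, 1))  # 24: a 5x5 'VALID' conv would shrink 28 to 24
```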
2.4 Building the CNN model
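# Input placeholders: flattened 28x28 images, one-hot labels, and the dropout keep probability
xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(xs, [-1, 28, 28, 1])  # [n, 784] >> [n, 28, 28, 1]

# conv1 layer
W_conv1 = weight_variable([5, 5, 1, 32])  # patch: 5x5, in size: 1, out size: 32
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)  # output size: 28x28x32
h_pool1 = max_pool_2x2(h_conv1)  # output size: 14x14x32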
# conv2 layer
W_conv2 = weight_variable([5, 5, 32, 64])  # patch: 5x5, in size: 32 (channels from pool1), out size: 64
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  # output size: 14x14x64
h_pool2 = max_pool_2x2(h_conv2)  # output size: 7x7x64
# FC1 layer
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64]) # [n,7,7,64] >> [n,7x7x64]
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# FC2 layer
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
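Counting trainable parameters shows where the model's size comes from: each layer has prod(weight shape) weights plus one bias per output unit. Assuming conv2 takes the 32 channels produced by pool1, the first fully connected layer holds about 98% of all parameters. A stdlib sketch of the count (`layers` and `count_params` are illustrative names):

```python
# (name, weight shape) for each layer of the CNN; each layer also has one bias per output unit
layers = [
    ("conv1", (5, 5, 1, 32)),
    ("conv2", (5, 5, 32, 64)),
    ("fc1", (7 * 7 * 64, 1024)),
    ("fc2", (1024, 10)),
]


def count_params(shape):
    """Weights (product of all dimensions) plus biases (last dimension)."""
    weights = 1
    for d in shape:
        weights *= d
    return weights + shape[-1]


counts = {name: count_params(shape) for name, shape in layers}
total = sum(counts.values())
print(counts)  # {'conv1': 832, 'conv2': 51264, 'fc1': 3212288, 'fc2': 10250}
print(total)   # 3274634
```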
2.5 Loss function and optimizer
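# Cross-entropy between the one-hot labels ys and the softmax output;
# clipping keeps tf.log away from 0, which would otherwise give -inf and NaN gradients
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)), axis=[1]))
# Adam optimizer with learning rate 1e-4
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)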
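Taking the log of a softmax output is fragile: when the logits (the values before `tf.nn.softmax`) are far apart, a probability can underflow to exactly 0, and log(0) is -inf. A common alternative is to compute the loss directly from the logits with the identity -log softmax(z)[y] = logsumexp(z) - z[y], which is what `tf.nn.softmax_cross_entropy_with_logits_v2` is designed for. The stdlib sketch below compares the two formulas for a single example; Python floats are float64, so an extreme logit gap of 800 is used to trigger the underflow (float32 underflows much sooner). `naive_xent` and `stable_xent` are illustrative names.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def naive_xent(label: int, logits: List[float]) -> float:
    """-log(softmax(z)[label]): the per-example term of the loss above, without clipping."""
    p = softmax(logits)[label]
    # tf.log(0.0) evaluates to -inf (so the loss is inf); math.log(0.0) would raise, so mimic TF
    return float("inf") if p == 0.0 else -math.log(p)


def stable_xent(label: int, logits: List[float]) -> float:
    """Same quantity via log-sum-exp: logsumexp(z) - z[label]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]


ordinary = [2.0, 1.0, 0.1]
print(naive_xent(0, ordinary), stable_xent(0, ordinary))  # both about 0.417

extreme = [0.0, 800.0]
print(naive_xent(0, extreme), stable_xent(0, extreme))  # inf 800.0
```

Both versions agree on ordinary inputs; only the logit-based one stays finite when the softmax underflows.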