还是听的莫烦大大的视频课程,顺便把代码码下来并且跑了一下,又加了一些注释。对于小白来说,要想完全搞懂这些代码,需要理解CNN的结构以及掌握tensorflow的相关知识。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
###导入数据集###
MNIST_data = 'C:/Users/think/Desktop/ML/mnist'
mnist = input_data.read_data_sets(MNIST_data, one_hot=True)
###定义accuracy函数###
def compute_accuracy(v_xs,v_ys):
global prediction
y_pre = sess.run(prediction, feed_dict={
xs:v_xs, keep_prob:1})
correct_prediction = tf.equal(tf.argmax(y_pre, 1),tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={
xs:v_xs,ys:v_ys,keep_prob:1})
return result
###定义卷积层的weight和bias###
#使用tf.truncated_normal产生随机变量初始化
def weight_variable(shape):
initial = tf.truncated_normal(shape