本文参考《TensorFlow技术解释与实战》一书,感谢李嘉璇大佬对社区的贡献
python版本3.5
tensorflow版本1.3 CPU版本
#coding=utf-8
from __future__ import print_function
from tensorflow.examples.tutorials.mnist import input_data
print("开始下载mnist数据集")
mnist = input_data.read_data_sets("E:\pycharm\work/MNIST_data", one_hot=True)
print("mnist数据集下载成功")
import tensorflow as tf
# 定义网络超参数
learning_rate = 0.001
training_iters = 200000
batch_size = 64
display_step = 20
# 定义网络参数
n_input = 784 # 输入的维度 28 * 28
n_classes = 10 # 标签的维度 0-9
dropout = 0.8 # Dropout 的概率,输出的可能性
# 占位符输入
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)
'''
构建网络模型
接下来定义AlexNet