刚入门tensorflow,最近在看黄文坚的《Tensorflow实战》,看到“Tensorflow实现多层感知机”这节,就手动把代码实现了一下。
导入mnist数据集的时候,不知道为什么总是很慢很慢,一直运行不停,所以就把数据集下载到本地了。网上有人提供云盘链接,搜一下就好。
下面代码中 ‘./MNIST_data/’ 是我电脑中mnist的路径(下载的时候自己选择),最前面的 ‘.’ 表示当前目录(如果是两个点,表示当前目录的上一级)。后面的 ‘/’ 表示读取该文件夹下的所有文件。注意,下载好的4个压缩文件不用解压缩!!!运行read_data_sets的时候会自动解压的。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist=input_data.read_data_sets('./MNIST_data/',one_hot=True)
##########构建神经网络第一步,前向建立计算图##########
#设置参数Variable并进行初始化
#输入层和hidden layer1的神经元数量分别为784和300
in_units=784
h1_units=300
#权重和偏置的初始化
W1=tf.Variable(tf.truncated_normal([in_units,h1_units],stddev=0.1))
b1=tf.Variable(tf.zeros([h1_units]))
W2=tf.Variable(