# -*- coding: UTF-8 -*-
'''
Created on 2017年12月8日
'''
#以下两句用于下载数据
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) #下载并加载mnist数据
#输入输出占位符
x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量,占位符,每一个sample都是784维,none表示可以有任意个sample
y_ = tf.placeholder("float", [None,10]) #占位符,每一个sample都是10维,因为是one_hot
#参数
W = tf.Variable(tf.zeros([784,10])) #权重,初始化值为全零,变量
b = tf.Variable(tf.zeros([10])) #偏置,初始化值为全零,变量
#进行模型建立及计算,y是预测,y_ 是实际
y = tf.nn.softmax(tf.matmul(x,W) + b)
#计算交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y+1e-10))
tf.scalar_summary('cross_entropy',cross_entropy)
#接下来使用BP算法来进行微调,以0.01的学习速率,使用的是简单的梯度下降算法----记住,这是一个优化算子
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#上面设置好了模型,添加初始化创建变量的操作
init = tf.initialize_all_variables()
#启动创建的模型,并初始化变量
sess = tf.Session()
sess.run(init) #init也是操作
merged = tf.merge_all_summaries() #collect the tf.xxxxx_summary
writer = tf.train.SummaryWriter('/home/tensorBoardLog/MNISTone',sess.graph)
#开始训练模型,循环训练1000次
for i in range(1000):
#随机抓取训练数据中的100个批处理数据点
batch_xs, batch_ys = mnist.train.next_batch(100) #mnist.train #这里边有疑问:next_batch is a method of the DataSet class
#https://stackoverflow.com/questions/40368697/where-does-next-batch-in-the-tensorflow-tutorial-batch-xs-batch-ys-mnist-trai
#可以在github上看到
summary,loss, _= sess.run([merged, cross_entropy, train_step], feed_dict={x:batch_xs,y_:batch_ys}) #train_step是一个操作,step表示每一步
#注意是操作;给模型必要的输入,以及必要的操作指示
writer.add_summary(summary, i)
print('range: %04d, loss = %-9f' % (i+1, loss))
''''' 进行模型评估 '''
#判断预测标签和实际标签是否匹配
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #1表示在1轴上,0轴表示的是样本index
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #tf.cast是类型转换函数
#计算所学习到的模型在测试数据集上面的正确率
print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) ) #mnist.test,注意了,x,y只是占位符,train test都可以用
#accuracy,注意了,这个时候w,b不会再变了,所以x进去自然会有一个y出来;feed_dict表示输入字典
ss