# Softmax-regression classifier for MNIST (TensorFlow 1.x graph API).
# Based on the Jikexueyuan translation of the TensorFlow beginners tutorial:
# http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Training hyperparameters.
LEARNING_RATE = 0.01
BATCH_SIZE = 100
TRAIN_STEPS = 1000

# The download is slow and can fail; the files may be fetched ahead of time
# and placed in the "MNIST_data/" folder instead.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Model: y = softmax(x W + b).
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))  # one bias per class, broadcast over the batch
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)  # softmax normalizes the scores; y is the prediction
y_ = tf.placeholder(tf.float32, [None, 10])  # y_ holds the true labels

# Cross-entropy summed over the batch.  It is computed from the logits rather
# than as -sum(y_ * log(y)): the hand-written form evaluates log(0) and turns
# the loss into NaN as soon as a predicted probability underflows to zero.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits)
)

# Gradient descent updates the weights.
optimization = tf.train.GradientDescentOptimizer(LEARNING_RATE)
train_step = optimization.minimize(cross_entropy)

# Model evaluation: argmax returns the index of the largest entry, so this
# compares the predicted class with the true class for every example.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Cast the booleans to floats and average them to obtain the accuracy.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Initialize the variables.
init = tf.global_variables_initializer()

# The context manager closes the session (and frees its resources) on exit,
# including when training raises.
with tf.Session() as sess:
    sess.run(init)

    for i in range(TRAIN_STEPS):
        # batch_xs: (100, 784), batch_ys: (100, 10), both numpy.ndarray.
        # mnist.test.images: (10000, 784), mnist.test.labels: (10000, 10).
        batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Accuracy on the full test set after training.
    Z = sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels},
    )
    print(Z)
Mnist1--Simple
最新推荐文章于 2021-09-24 10:57:35 发布