八、(1)神经网络——简单神经网络,手写数字预测。
本文通过预测手写数字来练习简单的神经网络。主要步骤为训练模型、保存模型、预测数据集、将数据集转图片、将图片转成数据集格式。
第一步:训练模型并保存。
"""
Created on Mon May 27 23:55:27 2019
@author: sun
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer("is_train", 1, "指定程序是预测还是训练")
def full_connected():
mnist = input_data.read_data_sets("./data/mnist/input_data/", one_hot=True)
with tf.variable_scope("data"):
x = tf.placeholder(tf.float32, [None, 784])
y_true = tf.placeholder(tf.int32, [None, 10])
with tf.variable_scope("fc_model"):
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name="w")
bias = tf.Variable(tf.constant(0.0, shape=[10]))
y_predict = tf.matmul(x, weight) + bias
with tf.variable_scope("soft_cross"):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
with tf.variable_scope("optimizer"):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
with tf.variable_scope("acc"):
equal_list = tf.equal(tf.argmax(y_true, 1