# DNN-mnist数据集识别 (DNN for MNIST dataset recognition)
# Environment:
#   win10
#   python3.6
#   tensorflow1.12
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST data sets rooted at "/mnist/", requesting one-hot labels.
mnist = input_data.read_data_sets(train_dir="/mnist/", one_hot=True)
# Hyperparameter settings.
# NOTE(review): the code consuming these values is outside this section;
# presumably they drive the optimizer and the training loop -- confirm there.
learning_rate = 0.01
batch_size = 128
num_steps = 500
display_step = 100
# Network parameters, listed from input width to output width.
# num_input and num_classess size the X / Y placeholders defined right after
# this section; num_input and n_hidden_1 give the shape of the "h1" weights.
# The spelling "num_classess" is kept as-is since other code refers to it.
num_input = 784
n_hidden_1 = 256
n_hidden_2 = 256
num_classess = 10
# Input settings: graph placeholders for the features (X) and labels (Y).
# The first dimension is None so the number of rows fed per run is not fixed;
# the second dimension is num_input for X and num_classess for Y.
X = tf.placeholder(dtype="float", shape=[None, num_input])
Y = tf.placeholder(dtype="float", shape=[None, num_classess])
# 网络权重设置
weights = {
"h1" : tf.Variable(tf.random_normal([num_input, n_hidden_1]))