# Handwritten digit (MNIST) recognition with an LSTM (用LSTM实现手写图片的数字识别)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Read the MNIST dataset from the 'MNIST_data' directory; labels are requested
# as one-hot vectors so they line up with the (batch, n_classes) label input.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
# Training hyperparameters.
lr = 1e-3  # presumably the optimizer's learning rate
training_iters = 100000
batch_size = 128

# Each 28x28 image is fed as a sequence: one time step per image row,
# with that row's pixels as the step's input features.
n_steps = 28  # time steps = number of image rows
n_inputs = 28  # features per step = pixels in one row
n_hidden_units = 128  # number of hidden units
n_classes = 10  # 10 output classes, one per digit
# Graph inputs: images as (batch, n_steps, n_inputs) sequences and their
# one-hot labels as (batch, n_classes); the batch dimension is left open.
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# Trainable parameters. 'in' projects n_inputs -> n_hidden_units and 'out'
# projects n_hidden_units -> n_classes (presumably the output logits).
# Weights use tf.random_normal's default initialization; biases start at 0.1.
# NOTE: creation order (w_in, w_out, b_in, b_out) is deliberately unchanged so
# TensorFlow's auto-generated variable names stay the same.
weights = {}
weights['in'] = tf.Variable(tf.random_normal([n_inputs, n_hidden_units]))
weights['out'] = tf.Variable(tf.random_normal([n_hidden_units, n_classes]))

biases = {}
biases['in'] = tf.Variable(tf.constant(0.1, shape=[n_hidden_units]))
biases['out'] = tf.Variable(tf.constant(0.1, shape=[n_classes]))
def rnn(x_input, weights_input, biases_input):
x = tf.reshape(x_input, [-1, n_inputs])
x_in = tf.matmul(x, weights_input['i