一个简单的DEMO:实现手写数字图片的识别
单向LSTM
利用的数据集是tensorflow提供的一个手写数字数据集。该数据集是一个包含55000张28*28的数据集。
训练100次
识别准确率还不是很稳定,但是从第17次开始就趋于相对稳定的状态了。
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
#import input_data
from tensorflow.examples.tutorials.mnist import input_data #####
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) #####
# configuration
# O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]
# ^ (O: output 28 vec from 28 vec input)
# |
# +-+ +-+ +--+
# |1|->|2|-> ... |28| time_step_size = 28
# +-+ +-+ +--+
# ^ ^ ... ^
# | | |
# img1:[28] [28] ... [28]
# img2:[28] [28] ... [28]
# img3:[28] [28] ... [28]
# ...
# img128 or img256 (batch_size or test_size 256)
# each input size = input_vec_size=lstm_size=28
# configuration variables
input_vec_size = lstm_size = 28 # 输入向量的维度
time_step_size = 28 # 循环层长度
batch_size = 128
test_size = 256
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01))
def model(X, W, B, lstm_size):
# X, input shape: (batch_size, time_step