LSTM实战:识别手写体数字

python3

pycharm

代码:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

input_vec_size = lstm_size = 28 #输入向量的维度
time_step_size = 28 #循环层长度

batch_size =128
test_size = 256

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape,stddev = 0.01))

def model(X,W,B,lstm_size):
    XT = tf.transpose(X,[1,0,2])
    XR = tf.reshape(XT,[-1,lstm_size])
    X_split = tf.split(XR,time_step_size,0)
    lstm = rnn.BasicLSTMCell(lstm_size,forget_bias = 1.0,state_is_tuple = True)
    outputs,_states = rnn.static_rnn(lstm,X_split,dtype = tf.float32)
    return tf.matmul(outputs[-1],W) + B , lstm.state_size
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)
trX,trY,teX,teY = mnist.train.images,mnist.train.labels,mnist.test.images,mnist.train.labels

trX = trX.reshape(-1,28,28)
teX = trX.reshape(-1,28,28)

X = tf.placeholder("float",[None,28,28])
Y = tf.placeholder("float",[None,10])
W = init_weights([lstm_size,10])
B = init_weights([10])

py_x,state_size = model(X,W,B,lstm_size)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x,labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001,0.9).minimize(cost)
predict_op = tf.argmax(py_x,1)

session_conf = tf.ConfigProto()
session_conf.gpu_options.allow_growth = True

with tf.Session(config = session_conf) as sess:
    tf.global_variables_initializer().run()
    for i in range(100):
        for start,end in zip(range(0,len(trX),batch_size),range(batch_size,len(trX)+1,batch_size)):
            sess.run(train_op,feed_dict={X:trX[start:end],Y:trY[start:end]})
        s = len(teX)
        test_indices = np .arange(len(teX))
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i,np.mean(np.argmax(teY[test_indices],axis = 1) == sess.run(predict_op,feed_dict={X:teX[test_indices]})))

实验结果:

源代码: 
https://github.com/geroge-gao/deeplearning/tree/master/LSTM

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值