Lstm实现MNIST手写字体识别(tensorflow版)

本文主要讲解如何使用Lstm实现MNIST手写字体识别,之前写过用cnn实现,这次我们使用lstm循环神经网络来实现。一切尽在代码中。
导入相关库以及MNIST数据

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials import mnist
from tensorflow.contrib import rnn
mnist = mnist.input_data.read_data_sets('./data',one_hot=True)

接下来我们设定参数
这里写图片描述

input_x就是图片中输入X的序列,相当于每一个输入X,都是1×28的大小,time_steps就是图片中绿色框的个数,图中用A表示的,也就是说总共有28个,因为图像是28×28

lr = 0.01  # 学习率
epochs = 30
batch_size = 64
input_x = 28 # 输入序列的长度,因为是28×28的大小,所以每一个序列我们设置长度为28,每一个输入都是28个像素点
time_steps = 28 #因为没有张图像为28×28,而每一个序列长度为1×28,所以总共281×28,
output_y = 10  #输入为10,因为共10
hidden_n = 128 #隐层的大小,这个参数就是比如我们输入是1×28的矩阵大小,隐藏为128,就是将输入维度变为1×128,当然lstm输入也是1×128

设置好placeholder,以便feed
第一个是输入,shape第一个参数是batch_size,第二个相当于序列总数,第三个是每个序列的长度

x = tf.placeholder(dtype=tf.float32,shape=[None,time_steps,input_x],name='input_x')
y = tf.placeholder(dtype=tf.float32,shape=[None,output_y],name='output_y')

设置好lstm输入的权重,因为输出的是1×128的矩阵,我们需要变为1×10的矩阵,因为类别为10,

weights = tf.Variable(tf.random_normal([hidden_n,output_y]),name='weights')
bais  = tf.Variable(tf.zeros([output_y]),name='bais')

上述工作做好以后,我们就开始定义LSTM了,其中stack函数的讲解看tf.unstack

def built_rnn(input_data,weights,bais):
    # 此处是对数据进行处理,因为我们输入数据是batch_size,time_steps,input_x,但是在喂给lstm的时候,我们是time_steps个28×1喂,
    #相当于一个list中包括time_steps个28×1

    input_data = tf.unstack(input_data,time_steps,axis=1)

    lstm_cell = rnn.BasicLSTMCell(hidden_n,forget_bias=1)

    lstm_output , lstm_state = rnn.static_rnn(lstm_cell,input_data,dtype=tf.float32)

    # 这个地方选择lstm_output[-1],也就是相当于最后一个输出,因为其实每一个cell(相当于图中的A)都会有输出,但是我们只关心最后一个
    #输出,
    output = tf.add(tf.matmul(lstm_output[-1],weights),bais)
    return output

接下来我们开始定义损失函数以及优化器

y_pred = built_rnn(x,weights,bais)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred,labels=y),name=loss)
optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits=y_pred,1),tf.argmax(y,1))),tf.float32)

然后就开始训练模型了:

with tf.Session(config=tf.ConfigProto(device_count={'gpu':0})) as sess:
    sess.run(tf.global_variables_initializer())
    n_batch = int(mnist.train.num_examples/batch_size)
    writer = tf.summary.FileWriter('./graphs/lstm_mnist',sess.graph)
    for epoch in range(epochs):
        total_loss = 0
        for i in range(n_batchsize):
            xs,ys = mnist.train.next_batch(batch_size)
            # 因为xs的shape是(None,784)的我们需要reshape一下
            xs = xs.reshape((batch_size,time_steps,input_x))


            _,tmp_loss = sess.run([optimizer,loss],feed_dict={x:xs,y:ys})

            total_loss += tmp_loss
        if epoch%10 == 0:
            train_data = mnist.train.images.reshape((-1,28,28))
            train_acc = sess.run(accuracy,feed_dict={x:train_data,y:mnist.train.labels})
            val_data = mnist.validation.images.reshape((-1,28,28))
            val_loss,val_acc = sess.run([loss,accuracy],feed_dict={x:val_data,y:mnist.validation.labels})

            print('Epoch  {}/{:.3f} train loss {:.3f},train_acc {:.3f},val_loss {:.3f},val_acc {:.3f} '.format(epoch,epochs,total_loss/n_batch,train_acc/n_batch,val_loss,val_acc))
    test_data = mnist.test.images.reshape((-1,28,28))
    test_loss,test_acc = sess.run([loss,accuracy],feed_dict={x:test_data,y:mnist.test.labels})
    print('test_loss {:.3f},test_acc {:.3f}'.format(test_loss,test_acc))
    writer.close()

这样就做完了

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值