RNN对MNIST数据集的识别

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
def build_graph():
    """Build a TF1 graph that classifies MNIST digits with a single LSTM layer.

    Each 28x28 image is treated as a sequence of 28 time steps (rows) with
    28 features per step.  Every row is linearly projected to 128 features,
    the sequence is run through a 128-unit ``BasicLSTMCell``, and the LSTM
    output at the last time step is mapped to 10 class logits by a dense
    layer.

    The graph accepts any batch size: the LSTM's zero initial state is
    derived from the fed batch rather than being fixed at graph-build time.

    Returns:
        dict: graph endpoints keyed by name:
            'images':      float32 placeholder, shape [None, 28, 28].
            'labels':      float32 placeholder, shape [None, 10]
                           (one-hot labels).
            'train_op':    Adam step (learning rate 0.01) minimizing the
                           softmax cross-entropy; running it also increments
                           'global_step'.
            'global_step': scalar, non-trainable int64 step counter.
            'accuracy':    fraction of the batch whose arg-max logit matches
                           the arg-max of the label.
    """
    n_steps, n_inputs, n_hidden, n_classes = 28, 28, 128, 10

    images = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
    labels = tf.placeholder(tf.float32, [None, n_classes])

    # Input projection applied independently to every row of every image.
    in_weight = tf.Variable(tf.random_normal([n_inputs, n_hidden]))
    in_bias = tf.Variable(tf.constant(0.1, shape=[n_hidden, ]))

    # Flatten (batch, step) so one matmul projects all rows at once, then
    # restore the [batch, step, feature] layout expected by dynamic_rnn.
    X = tf.reshape(images, shape=[-1, n_inputs])
    X_in = tf.matmul(X, in_weight) + in_bias
    X_in = tf.reshape(X_in, shape=[-1, n_steps, n_hidden])

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden)

    # Passing `dtype` instead of an explicit `initial_state` makes
    # dynamic_rnn create the zero state for whatever batch size is fed;
    # a zero_state built with a literal batch size would reject any other.
    outputs, final_state = tf.nn.dynamic_rnn(
        lstm_cell, X_in, dtype=tf.float32, time_major=False)

    # Classify from the LSTM output at the final time step.
    results = tf.layers.dense(outputs[:, -1, :], n_classes)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=results)

    # Integer step counter; minimize() increments it once per training step.
    global_step = tf.get_variable("step", [], dtype=tf.int64,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    train_op = tf.train.AdamOptimizer(0.01).minimize(loss, global_step=global_step)

    correct_pred = tf.equal(tf.argmax(results, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    return {'images': images,
            'labels': labels,
            'train_op': train_op,
            'global_step': global_step,
            'accuracy': accuracy}

if __name__ == '__main__':
    # Script entry point: train the LSTM classifier on MNIST mini-batches,
    # printing the training-batch accuracy every `log_every` steps and
    # stopping once the global step exceeds `max_step`.
    batch_size = 128
    log_every = 50
    max_step = 1200

    mnist = input_data.read_data_sets('/home/ly/MNIST_data', one_hot=True)
    with tf.Session() as sess:
        graph = build_graph()
        sess.run(tf.global_variables_initializer())

        fetches = [graph['train_op'], graph['global_step'], graph['accuracy']]

        # Data is fed through placeholders, so the graph has no queue runners:
        # no Coordinator / start_queue_runners / OutOfRangeError handling is
        # needed, and the loop is bounded by the step counter alone.
        while True:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Reshape to the [batch, 28, 28] layout of the images placeholder.
            batch_x = batch_x.reshape([-1, 28, 28])

            feed_dict = {graph['images']: batch_x,
                         graph['labels']: batch_y}
            _, step, accuracy = sess.run(fetches, feed_dict=feed_dict)
            if step % log_every == 0:
                print(accuracy)
            if step > max_step:
                break

        print('==================Train Finished================')
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值