RNN中网络结构的理解

在使用tensorflow对构建rnn模型的时候,有几个参数一直不能很好的理解它本身的结构,这对后续网络的修改产生了很大的问题,在网上查阅资料后对其中一些参数结构进行总结。
例子代码如下:

#!/usr/bin/env python3
# -*- coding:utf-8 -*-


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def Weight_Variable(shape):
    initial = tf.truncated_normal(shape=shape, stddev=0.1)
    return tf.Variable(initial)


def Bias_Variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def RNN(x, weights, bias, n_times, n_inputs, n_hidden_units):
    # inputs shape:(100,28,28)
    inputs = tf.reshape(x, [-1, n_times, n_inputs])
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units)
    # todo tensorflow 删除了core_rnn_cell
    # todo lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell
    # output shape:(100,28,100) finall_state为一个包含两个元素的tuple,其中每个元素的shape都为(100,100)
    output, finall_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
    # predicton shape:(100,10)
    prediction = tf.matmul(finall_state[1], weights) + bias
    return prediction


def main():
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    batch_size = 100
    n_batch = mnist.train.num_examples // batch_size
    n_inputs = 28
    n_times = 28
    n_hidden_units = 100
    n_classes = 10
    x = tf.placeholder(tf.float32, [None, n_inputs * n_times])
    y = tf.placeholder(tf.float32, [None, n_classes])
    weights = Weight_Variable([n_hidden_units, n_classes])
    bias = Bias_Variable([n_classes])
    prediction = RNN(x, weights, bias, n_times, n_inputs, n_hidden_units)
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
    optimizer = tf.train.AdamOptimizer(1e-4)
    train = optimizer.minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(11):
            for batch in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # todo batch_xs.shape=(batch_size,n_input*n_times)
                sess.run(train, feed_dict={x: batch_xs, y: batch_ys})
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
            print('Iter:' + str(epoch + 1) + 'Testing accuracy=' + str(acc))


if __name__ == '__main__':
    main()

以MNIST手写数字识别数据集为例,原数据集中每个图片大小均为[28,28],我们设置其中的batch_size为100,原始输入数据就是[100,28*28],进行reshape操作为模型可以处理的数据[100,28,28]。
最关键的是tf.contrib.rnn.BasicLSTMCell(n_hidden_units)中n_hidden_units的含义,查阅资料后得知为网络输出的向量维数。RNN中我们每个step输入28维的数据,每个step输出100维的数据,output输出每个step的结果,所以最后output的shape为[100,28,100],finall_state为包含两个状态(c_state,h_state)的元组,其中每个状态仅含有最后一个step的数据,所以两者的shape都为[100,100],其中m_state的内容和output中每个batch最后一行(也就是最后一个step)一样。

参考链接https://xdrush.github.io/2018/02/12/RNN原理详解以及tensorflow中的RNN实现/

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值