tf.nn.rnn_cell.MultiRNNCell

之前没接触过RNN,直接看代码https://github.com/siyuanzhao/2016-EDM
这是关于知识追踪Going Deeper with Deep Knowledge Tracing的官方代码
使用TensorFlow0.10.0,那是我还没学编程呢
所以能看懂大概作用的函数就不细查,关于网络构造还是得看看。

RNN函数

在这里插入图片描述

tf.nn.rnn_cell.LSTMCell(final_hidden_size, state_is_tuple=True)
state_is_tuple:默认为True,接受状态和返回状态是(c_state,m_state)元组,即状态 c t c^t ct h t h^t ht分开记录。如果为False,则沿列轴连接它们,只返回一个concate([c_state,m_state],axis=-1)


tf.nn.rnn_cell.MultiRNNCell __init__(cells,state_is_tuple=True)

hidden_layers = []
tf.nn.rnn_cell.MultiRNNCell(hidden_layers, state_is_tuple=True)

cells,rnn类单元的list。list的大小就是网络层数的多少
state_is_tuple:与LSTM中的参数一样,不过这里返回的是n个(c_state, m_state)元组,n的大小也为上面cell list的长度。

下面是两个博主的代码:

import tensorflow as tf
import numpy as np

num_units = [50, 200, 300]
cells = [tf.nn.rnn_cell.LSTMCell(num_unit) for num_unit in num_units]
mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
print(mul_cells.state_size)

input = np.random.rand(32, 100)
inputs = tf.constant(value=input, shape=(32, 100), dtype=tf.float32)

h0 = mul_cells.zero_state(32, np.float32)
output, h1 = mul_cells.__call__(inputs, h0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))
    print(sess.run(tf.shape(output)))
    print(sess.run(tf.shape(h1[0].c)))
    print(sess.run(tf.shape(h1[1].c)))
    print(sess.run(tf.shape(h1[2].c)))

在这里插入图片描述

reference: https://zhuanlan.zhihu.com/p/99421590

import tensorflow as tf

batch_size=10
depth=128

inputs=tf.Variable(tf.random_normal([batch_size,depth]))

previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))
previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))
previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))

num_units=[100,200,300]
print(inputs)

cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)

outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))

print(outputs.shape) #(10, 300)
print(states[0]) #第一层LSTM
print(states[1]) #第二层LSTM
print(states[2]) ##第三层LSTM
print(states[0].h.shape) #第一层LSTM的h状态,(10, 100)
print(states[0].c.shape) #第一层LSTM的c状态,(10, 100)
print(states[1].h.shape) #第二层LSTM的h状态,(10, 200)
(10, 300)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Add_1:0' shape=(10, 100) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Mul_2:0' shape=(10, 100) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Add_1:0' shape=(10, 200) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Mul_2:0' shape=(10, 200) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Add_1:0' shape=(10, 300) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Mul_2:0' shape=(10, 300) dtype=float32>)
(10, 100)
(10, 100)
(10, 200)

旧版本应该要tf.Session()吧 自行测试
reference:https://www.cnblogs.com/yanshw/p/10515436.html

https://www.zhihu.com/people/simon-29-55-12/posts
解释了一连串的tf.nn.的函数

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值