I hadn't worked with RNNs before, so I went straight to the code: https://github.com/siyuanzhao/2016-EDM
This is the official code for the knowledge-tracing paper Going Deeper with Deep Knowledge Tracing.
It uses TensorFlow 0.10.0, which dates from before I had even started programming.
So for functions whose rough purpose I can tell at a glance I won't dig into the details, but the network construction deserves a closer look.
RNN functions
tf.nn.rnn_cell.LSTMCell(final_hidden_size, state_is_tuple=True)
state_is_tuple: defaults to True, meaning the accepted and returned states are (c_state, m_state) tuples, i.e. the cell state c_t and the hidden state h_t are kept separately. If False, the two are concatenated along the column axis and a single concat([c_state, m_state], axis=-1) is returned.
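To make this concrete, here is a minimal sketch of mine (the hidden size 64 is an arbitrary choice, not from the repo) showing what state_size reports in each mode:

import tensorflow as tf

# with state_is_tuple=True the state is an LSTMStateTuple(c, h)
cell_tuple = tf.nn.rnn_cell.LSTMCell(64, state_is_tuple=True)
print(cell_tuple.state_size)   # LSTMStateTuple(c=64, h=64)

# with state_is_tuple=False, c and h are concatenated column-wise
cell_concat = tf.nn.rnn_cell.LSTMCell(64, state_is_tuple=False)
print(cell_concat.state_size)  # 128, i.e. 64 + 64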
tf.nn.rnn_cell.MultiRNNCell.__init__(cells, state_is_tuple=True)
hidden_layers = []  # to be filled with one RNN cell per layer
tf.nn.rnn_cell.MultiRNNCell(hidden_layers, state_is_tuple=True)
cells: a list of RNN cells; the length of the list is the number of layers in the network.
state_is_tuple: same as the parameter on LSTMCell, except that here n (c_state, m_state) tuples are returned, where n is again the length of the cell list above.
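Presumably the repo fills that list in a loop, roughly like the sketch below (hidden_layer_num is a hypothetical name of mine; only final_hidden_size appears in the call above):

import tensorflow as tf

hidden_layer_num = 2      # hypothetical layer count, not from the repo
final_hidden_size = 200   # hypothetical size for illustration
hidden_layers = []
for _ in range(hidden_layer_num):
    # one LSTMCell per layer; the list length becomes the network depth
    hidden_layers.append(tf.nn.rnn_cell.LSTMCell(final_hidden_size, state_is_tuple=True))
hidden_cell = tf.nn.rnn_cell.MultiRNNCell(hidden_layers, state_is_tuple=True)
print(len(hidden_layers))  # 2 layers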
Below is code from two bloggers:
import tensorflow as tf
import numpy as np

num_units = [50, 200, 300]  # hidden size of each of the three layers
cells = [tf.nn.rnn_cell.LSTMCell(num_unit) for num_unit in num_units]
mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
print(mul_cells.state_size)  # a tuple of three LSTMStateTuples, one per layer

inputs = tf.constant(np.random.rand(32, 100), dtype=tf.float32)  # batch 32, feature size 100
h0 = mul_cells.zero_state(32, tf.float32)  # all-zeros initial state
output, h1 = mul_cells(inputs, h0)  # one time step; output is the top layer's h

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))
    print(sess.run(tf.shape(output)))   # [32 300]
    print(sess.run(tf.shape(h1[0].c)))  # [32  50]
    print(sess.run(tf.shape(h1[1].c)))  # [32 200]
    print(sess.run(tf.shape(h1[2].c)))  # [32 300]
reference: https://zhuanlan.zhihu.com/p/99421590
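A note on that example: zero_state(32, tf.float32) builds the all-zeros initial state in the same nested structure the cell returns, one LSTMStateTuple per layer, which is why h1[0].c, h1[1].c and h1[2].c come back with widths 50, 200 and 300, and why output (the top layer's h) has width 300.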
import tensorflow as tf

batch_size = 10
depth = 128
inputs = tf.Variable(tf.random_normal([batch_size, depth]))
# one (c, h) pair per layer; the widths must match each layer's num_units
previous_state0 = (tf.random_normal([batch_size, 100]), tf.random_normal([batch_size, 100]))
previous_state1 = (tf.random_normal([batch_size, 200]), tf.random_normal([batch_size, 200]))
previous_state2 = (tf.random_normal([batch_size, 300]), tf.random_normal([batch_size, 300]))
num_units = [100, 200, 300]
print(inputs)
cells = [tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
outputs, states = mul_cells(inputs, (previous_state0, previous_state1, previous_state2))
print(outputs.shape)      # (10, 300)
print(states[0])          # first-layer LSTM state
print(states[1])          # second-layer LSTM state
print(states[2])          # third-layer LSTM state
print(states[0].h.shape)  # first layer's h state, (10, 100)
print(states[0].c.shape)  # first layer's c state, (10, 100)
print(states[1].h.shape)  # second layer's h state, (10, 200)

Output:
(10, 300)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Add_1:0' shape=(10, 100) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_0/basic_lstm_cell/Mul_2:0' shape=(10, 100) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Add_1:0' shape=(10, 200) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_1/basic_lstm_cell/Mul_2:0' shape=(10, 200) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Add_1:0' shape=(10, 300) dtype=float32>, h=<tf.Tensor 'multi_rnn_cell/cell_2/basic_lstm_cell/Mul_2:0' shape=(10, 300) dtype=float32>)
(10, 100)
(10, 100)
(10, 200)
Older versions probably need a tf.Session() to run this; test it yourself.
reference: https://www.cnblogs.com/yanshw/p/10515436.html
https://www.zhihu.com/people/simon-29-55-12/posts (explains a whole series of tf.nn.* functions)
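One last point: both blog examples above call the stacked cell for a single time step, but a real model such as DKT has to unroll it over a whole input sequence. A minimal sketch of that in TF1 style, using tf.nn.dynamic_rnn with made-up sizes (batch 32, 50 steps, 124-dimensional input; none of these come from the repo):

import tensorflow as tf

batch_size, max_steps, input_dim = 32, 50, 124  # made-up sizes
num_units = [100, 200]                          # two stacked LSTM layers

inputs = tf.placeholder(tf.float32, [batch_size, max_steps, input_dim])
cells = [tf.nn.rnn_cell.LSTMCell(n, state_is_tuple=True) for n in num_units]
stacked = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

# outputs holds the TOP layer's h at every step; final_state is one
# LSTMStateTuple per layer, exactly as in the single-step examples above
outputs, final_state = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)
print(outputs.shape)           # (32, 50, 200)
print(final_state[0].c.shape)  # (32, 100)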