LSTM-tf.nn.static_rnn与tf.nn.dynamic_rnn.用法详解

34 篇文章 0 订阅
30 篇文章 2 订阅

最近研究LSTM的网络,想将LSTM应用到图像上,查资料发现,用到图像上的LSTM叫ConvLSTM,在这里记录下最核心的两个函数用法:
tf.nn.static_rnn与tf.nn.dynamic_rnn.

这两个函数是tensoflow针对RNN的LSTM提供的两个函数,两个函数的功能上其实差不多,但是tf.nn.dynamic()函数更加灵活.这里我还主要讲解函数用法没和两个函数输入数据形式的区别:

1.tf.nn.dynamic_rnn

def dynamic_rnn(cell, inputs, sequence_length=None, 
        initial_state=None,dtype=None,
        parallel_iterations=None,swap_memory=False,time_major=False, scope=None):
return output,state

parameter:
cell:参数:cell,自己定义的LSTM的细胞单元,如果是convLSTM,自己写也可以,.
下面两个链接提供cell函数(都可以用):https://github.com/TakuyaShinmura/conv_lstm/blob/master/conv_lstm_cell.py
https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py

也可以用tensorflow自带的API:tf.contrib.rnn.ConvLSTMCell()

inputs:一个5维的变量,[batchsize,timestep,image.shape],搭配time_major=False.这里还补充一点,就是叫dynamic的原因,就是输入数据的time_step不一定要相同,如果长短不一,会自动跟短的补0,但是处理时候,不会处理0,在0前面就截止了.这就是dynamic对比static的好处.

time_major: If true, these Tensors must be shaped [max_time, batch_size, depth].
If false, these Tensors must be shaped `[batch_size, max_time, depth]
其实很好理解,如果是true,就是time_step是主导,最前面就是max_time,如果是false,batch_size占主导,batch_size在前面,就是我上面的5维变量输入形式.
其他的参数都可以不用设置,默认就行.

最后所以下函数返回值:这里output是每个cell输出的叠加,比如我输入数据[1,5,100,100,3],是一个长度为5 的视频序列,则返回output为[1,5,100,100,3],5个cell细胞的输出状态,state是一个元组类型的数据,有(c和h两个变量)就是存储LSTM最后一个cell的输出状态,我一般用的是output的最后一个输出..用state输出也行,就是取元组中的h变量.
用下列语句输出:

 outputs = tf.transpose(outputs,[1,0,2,3,4])#这一步必不可少,将max_time提前,后面的output[-1]才是最后一个time的输出,也就是最后一个cell的输出
 last_output=outputs[-1]
然后在处理last_output
'''''''
处理过程
'''''''
return result

2.tf.nn.static_rnn()

def static_rnn(cell, inputs,initial_state=None, dtype=None, 
         sequence_length=None, scope=None)

return outputs,state

参数:cell,自己定义的LSTM的细胞单元,如果是convLSTM,自己写也可以,.
下面两个链接提供cell函数(都可以用):https://github.com/TakuyaShinmura/conv_lstm/blob/master/conv_lstm_cell.py
https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py

也可以用tensorflow自带的API:tf.contrib.rnn.ConvLSTMCell()

参数:input,和上面的tf.nn.dynamic_rnn有很大不同,这里的input输入是一个List,记住,是list,也就是输入是一个[ ],里面每个List元素都是一组图片,比如[iamges1,images2,images3],images 是有多张图片的一个图片序列.

sequence_length:序列长度,可以不设置,因为input里面可以指定.

最后是返回值:output也是一个list,list的每个元素对应每个image1,image2…的输入,取最后一个list元素,也就是outputs[-1],就是最后一个cell的输出.state和上面一样,记录最后一个cell的状态.

last_output=outputs[-1]
然后在处理last_output
'''''''
处理过程
'''''''
return result

最后这个链接,直接用代码总结,也很好:https://manutdzou.github.io/2017/11/27/tensorflow-lstm.html

最后是一个torch版本的综合应用,弄懂有点困难:https://github.com/viorik/ConvLSTM

好了,相信已经明白怎么使用这个ConLSTM结构了,本人也是初学,有不对的地方多多留言交流.

  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在 TensorFlow 2.x 中,`tf.keras.rnn.static_rnn` 函数已经被弃用,如果想要在 TensorFlow 2.x 中使用 RNN 网络,可以使用 `tf.keras.layers.RNN` 类,或者使用 `tf.keras.layers.SimpleRNN`、`tf.keras.layers.LSTM`、`tf.keras.layers.GRU` 等层来构建 RNN 网络。 如果你想使用类似 `tf.keras.rnn.static_rnn` 的函数,可以使用 `tf.compat.v1.nn.static_rnn` 函数来代替,如下所示: ```python import tensorflow as tf # 定义 RNN 网络结构 cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units=128) inputs = tf.random.normal(shape=(32, 10, 10)) seqTimeSteps = tf.unstack(inputs, axis=1) outputs, state = tf.compat.v1.nn.static_rnn(cell, seqTimeSteps, dtype=tf.float64) # 定义模型 model = tf.keras.Model(inputs=inputs, outputs=outputs) # 编译模型 model.compile(optimizer='adam', loss='mse', metrics=['mae']) # 训练模型 x_train = tf.random.normal(shape=(32, 10, 10)) y_train = tf.random.normal(shape=(32, 10, 128)) model.fit(x_train, y_train, batch_size=8, epochs=10) ``` 在上面的代码中,我们首先使用 `tf.compat.v1.nn.rnn_cell.BasicRNNCell` 定义了一个基础的 RNN 单元,然后使用 `tf.compat.v1.nn.static_rnn` 函数对输入数据进行处理。最后使用 `tf.keras.Model` 定义了一个完整的模型,并使用 `model.compile` 和 `model.fit` 进行模型的编译和训练。需要注意的是,在使用 `tf.compat.v1.nn.static_rnn` 函数时,需要将输入数据转换为 `list` 类型的张量,并且需要指定 `dtype` 参数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值