Tensorflow API中LSTM的参数提取及手工复现模型推理(不使用API复现)

本文介绍了如何从Tensorflow 1.14及之前版本的模型中提取LSTM参数,并手动复现推理过程,特别是针对没有GPU或无法使用API的场景。文章详细解析了LSTM的结构,并提供了两种API的代码示例,强调在推理时关键接口是`inputs`和`num_units`。最后,给出了无API推理的步骤,包括参数读取和计算过程。
摘要由CSDN通过智能技术生成

Tensorflow API中LSTM的参数提取及手工复现模型推理(不使用API复现)

对于深度学习任务,使用Tensorflow的API进行训练,推理是目前主流的实现方式。但是训练模型和运用模型进行推理,往往处于不同的工作场景。比如训练模型使用服务器的GPU集群进行加速训练,而通常希望运用模型的场景是在嵌入式设备上,ARM、FPGA、或者像我的需求一样,需要设计ASIC来进行加速推理(反正应用场景就是没有GPU也多半装不上TensorFlow )。
附:本文不提供LSTM模型训练的教程,本文适用于需要从Tensorflow 模型中提取LSTM参数,并且以及需要手动复现的读者。代码由Python编写。提取的模型适用于TensorFlow 1.14及之前的版本。介绍的两种LSTM API手动复现时偏置计算时有细微区别,请读者注意。

LSTM的结构

由于这些设备上往往只能使用训练好的参数进行计算,但是成本限制或者产品需求不能使用API,因此必须要搞清楚网络结构,并且手动复现这个计算过程,以及提取出模型保存好的参数来进行推理。首先介绍LSTM的基本结构:
[结构图引自博客 ]https://blog.csdn.net/kami0116/article/details/94749564.
LSTM基本结构
其中的c_prev和h_prev是前一时刻的状态,c,h是当前时刻的输出。通常LSTM中会有多个隐藏层,隐藏层可以理解成多个LSTM的串接,前一个LSTM的输出作为下一个LSTM的输入。但是在实际应用中,如FPGA或者ASIC实现,从硬件思维来考虑,由于LSTM对于时间连续性的要求,使得多个LSTM和单个LSTM在算一路数据而言,其耗时相当。c_prev和h_prev是一组向量,而第一个LSTM单元的c_prev和h_prev被默认为全为0的向量值。
Wf,Wi,Wj,Wo则是LSTM的权重信息,j 在文献中通常会用C来表示,为了区分以及方便对照着这个结构图写代码,此处用 j 表示。

formula
公式中括号的[x,hprev]表示向量的拼接,[x,hprev] · Wf ,代表矩阵乘法,硬件中的操作是MAC(乘累加),*则是矩阵中相同位置的数字相乘没有累加这一操作。σ是sigmoid激活函数,tanh也是激活函数。

LSTM结构介绍完毕,代码分析。

TensorFlow的API中,有两个构造LSTM的函数:
1:

tf.contrib.rnn.BasicLSTMCell(
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值