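Deep Learning with LSTM: A C Implementation of the Forward Pass Based on TF Model Parameters
Most introductions to recurrent neural networks concentrate on the internal structure of the recurrent-layer node. That is not wrong, but it is not very friendly to someone who is just starting out. Following my own learning habits, this article first explains how the whole recurrent network operates by tracing the data flow through it, then introduces the structure of the LSTM node, then shows an LSTM example implemented in TensorFlow and explains the details of the LSTM model parameters in TF, and finally implements the LSTM forward computation in C using the TF model parameters.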
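1. Data Preparation
We use text analysis as the example for describing how data moves through the network. The TensorFlow LSTM implementation in Section 4 uses the MNIST dataset instead.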
Suppose the sentence we want to process is $I\ love\ Deep\ Learning$.
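Text cannot be fed into the network directly, so the sentence must first be vectorized. Suppose each word becomes a word vector of dimension $5$, and the sentence length is known to be $4$. The whole sentence can then be written as $V=[v_0,v_1,v_2,v_3]$, a $4\times5$ matrix. Note that the rows correspond to the sequence length and the columns to the feature dimension.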
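Other kinds of time series are handled in much the same way; only the framing and the feature extraction differ. What matters is pinning down the sequence length and the feature dimension.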
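As a minimal sketch of this vectorization step, the snippet below maps each word to a 5-dimensional vector and stacks the vectors into a matrix with one row per time step. The lookup table `TOY_EMBEDDING` and its values are made up for illustration and do not come from any trained embedding.

```python
# Toy word-vector table: the 5-dimensional values are invented for illustration,
# not taken from any trained embedding.
TOY_EMBEDDING = {
    "I":        [0.1, 0.0, 0.3, 0.2, 0.5],
    "love":     [0.7, 0.1, 0.0, 0.4, 0.2],
    "Deep":     [0.2, 0.9, 0.1, 0.0, 0.3],
    "Learning": [0.0, 0.4, 0.8, 0.6, 0.1],
}

def vectorize(sentence, table):
    """Map each word to its vector; the result has one row per time step."""
    return [table[word] for word in sentence.split()]

V = vectorize("I love Deep Learning", TOY_EMBEDDING)
seq_len, feature_dim = len(V), len(V[0])
print(seq_len, feature_dim)  # rows = sequence length, columns = feature dimension
```

The two numbers printed are exactly the quantities that have to be fixed before the network can be defined: the sequence length and the feature dimension.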
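2. Data Flow in the Network
Before tracing the data flow, we define the network structure. From Section 1 the feature dimension is $5$, so the input layer also has $5$ nodes. For the hidden layer we define an LSTM recurrent layer with $2$ nodes, followed by an output layer with only $1$ node. The network structure is shown in the figure below. Each node in the second layer (the LSTM recurrent layer) has an arrow pointing back to itself; the drawing tool (NN SVG) can only draw fully connected networks, so these arrows were added by hand.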
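Before feeding data into the network, one more thing must be settled: the number of hidden states $h$ and cell states $C$. The LSTM recurrent layer we defined has only two nodes, so there are $2$ cell states $C$. The number of hidden states $h$ equals the number of cell states, that is, the number of nodes in the recurrent layer.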
From the previous section, the input sample is $V=[v_0,v_1,v_2,v_3]$. Let the initial hidden states and cell states of the two LSTM nodes be $h_{00},h_{10},c_{00},c_{10}$, where the first index is the node and the second index is the time step.
Step 1: feed $v_0$ into the network. The states of the LSTM nodes are updated to $h_{01},h_{11},c_{01},c_{11}$. Note that at this point no data flows on to the output layer.
Step 2: feed $v_1$ into the network. The states of the LSTM nodes are updated to $h_{02},h_{12},c_{02},c_{12}$.
Step 3: feed $v_2$ into the network. The states of the LSTM nodes are updated to $h_{03},h_{13},c_{03},c_{13}$.
Step 4: feed $v_3$ into the network. The states of the LSTM nodes are updated to $h_{04},h_{14},c_{04},c_{14}$.
Step 5: the output of the LSTM nodes (which is simply $h_{04},h_{14}$) flows to the output-layer node, and the computation is complete.
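As shown above, one sample is processed by the recurrent layer $N$ times, where $N$ is the sequence length ($4$ in this example). Each pass takes the feature vector of a different time step as input, and until the $N$-th pass has finished, the output of the recurrent layer does not flow on to the next layer.
That is roughly the path data takes through an LSTM network. The computation inside an LSTM node requires some understanding of the node structure, which is covered in Section 3.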
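The stepping scheme described in this section can be sketched as a plain loop. The snippet is only an illustration of the data flow: `toy_step` is a placeholder update invented for this sketch and is not an LSTM, and `output_layer` stands in for the single output node.

```python
def run_sequence(V, step_fn, h, c):
    """Feed the rows of V one at a time; return the states after the last step."""
    for v_t in V:                  # N iterations, N = sequence length
        h, c = step_fn(v_t, h, c)  # only the recurrent states change here
    return h, c

def toy_step(v_t, h, c):
    # Placeholder update (NOT an LSTM): each node accumulates the mean of the input.
    mean = sum(v_t) / len(v_t)
    c = [c_j + mean for c_j in c]
    h = [0.5 * c_j for c_j in c]
    return h, c

def output_layer(h, weights, bias):
    # Single output node: weighted sum of the final hidden states.
    return sum(w * h_j for w, h_j in zip(weights, h)) + bias

V = [[1.0] * 5 for _ in range(4)]          # 4 time steps, 5 features
h, c = run_sequence(V, toy_step, [0.0, 0.0], [0.0, 0.0])
y = output_layer(h, [1.0, 1.0], 0.0)       # called once, after the last step
print(h, c, y)
```

The point of the structure is that `output_layer` sits outside the loop: the recurrent states are updated four times, while the next layer sees only the final hidden states.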
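3. Structure of the LSTM Node
For an introduction to the LSTM structure, see the article Understanding LSTM Networks. All figures in this section come from that article, whose diagrams are clear and well drawn; they are reproduced here with thanks to the original author.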
Forget gate:
It is computed as:
$$f_t = \sigma(W_fx_t+U_fh_{t-1}+b_f)$$
From the formula, $f_t\in(0,1)$. When $f_t\to0$, the previous cell memory $C_{t-1}$ is forgotten; when $f_t\to1$, the previous cell memory $C_{t-1}$ is retained.
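To make the two limits concrete, here is a small numeric sketch of how the sigmoid output scales the previous cell memory. The pre-activation values and the memory value `c_prev` are arbitrary numbers chosen for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

c_prev = 2.0                      # previous cell memory C_{t-1} (arbitrary value)
for z in (-10.0, 0.0, 10.0):      # pre-activation W_f x_t + U_f h_{t-1} + b_f
    f_t = sigmoid(z)
    print(f"z={z:+.0f}  f_t={f_t:.5f}  kept memory={f_t * c_prev:.5f}")
```

A strongly negative pre-activation drives the kept memory toward zero, while a strongly positive one leaves it almost unchanged.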
Input gate:
It is computed as:
$$i_t = \sigma(W_ix_t+U_ih_{t-1}+b_i)$$
$$\widetilde C_t = \tanh(W_cx_t+U_ch_{t-1}+b_c)$$
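Likewise, $i_t\in(0,1)$. When $i_t\to0$, the current input $x_t$ is blocked; when $i_t\to1$, the current input is let through. Following the article cited above, the computation of $\widetilde C_t$ is described together with the input gate. In the TensorFlow source code, however, it is called the cell gate; together with $i_t$ it determines how much of the current input enters the cell.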
With $f_t$ and $i_t$ in hand, they combine with $C_{t-1}$ and $\widetilde C_t$ to update the cell state $C_t$. The update formula for $C_t$ is:
$$C_t = f_t*C_{t-1}+i_t*\widetilde C_t$$
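The cell-state update is element-wise, with one entry per LSTM node. The following sketch applies it to two nodes with hand-picked gate values (the numbers are illustrative only, not outputs of a trained model).

```python
def update_cell_state(f_t, i_t, c_prev, c_tilde):
    """Element-wise C_t = f_t * C_{t-1} + i_t * C~_t, one entry per LSTM node."""
    return [f * c + i * g for f, i, c, g in zip(f_t, i_t, c_prev, c_tilde)]

# Node 0 mostly keeps its memory; node 1 mostly overwrites it with the candidate.
c_t = update_cell_state(f_t=[0.9, 0.1], i_t=[0.1, 0.9],
                        c_prev=[2.0, 2.0], c_tilde=[-1.0, -1.0])
print(c_t)
```

Node 0 ends up close to its old memory of 2.0, while node 1 ends up close to the candidate value of -1.0.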
Output gate:
It is computed as:
$$o_t = \sigma(W_ox_t+U_oh_{t-1}+b_o)$$
The output gate value satisfies $o_t\in(0,1)$. When $o_t\to0$, the output gate is closed at the current time step; when $o_t\to1$, the output gate is open.
The output value is computed as:
$$h_t = o_t*\tanh(C_t)$$
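Putting the formulas of this section together, one time step of an LSTM layer can be sketched in plain Python as follows. The parameter layout is an assumption made for this sketch: each gate has its own `(W, U, b)` tuple, with `W` holding one row per unit. This is not the layout Keras uses to store its weights.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, U, b, x, h):
    """Return W x + U h + b; W has one row per unit, U is units x units."""
    return [sum(w * xi for w, xi in zip(W[j], x))
            + sum(u * hi for u, hi in zip(U[j], h))
            + b[j]
            for j in range(len(b))]

def lstm_step(x_t, h_prev, c_prev, params):
    """One time step. params maps 'f', 'i', 'c', 'o' to (W, U, b) tuples."""
    pre = {name: affine(*params[name], x_t, h_prev) for name in "fico"}
    f_t = [sigmoid(z) for z in pre["f"]]          # forget gate
    i_t = [sigmoid(z) for z in pre["i"]]          # input gate
    c_tilde = [math.tanh(z) for z in pre["c"]]    # candidate cell state
    o_t = [sigmoid(z) for z in pre["o"]]          # output gate
    c_t = [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, c_tilde)]
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return h_t, c_t

# One unit, one feature, all weights zero: every gate is sigmoid(0) = 0.5.
zero = ([[0.0]], [[0.0]], [0.0])
params = {name: zero for name in "fico"}
h_t, c_t = lstm_step([1.0], [0.0], [1.0], params)
print(h_t, c_t)
```

With all weights at zero the candidate is tanh(0) = 0, so the new cell state is simply half of the previous one, which makes the example easy to check by hand.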
4. An LSTM Model in TensorFlow
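import tensorflow as tf

# Load the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# train_flatten = [x.flatten() for x in x_train]
# Arrange the input as [number of samples, sequence length, feature dimension]
x_train, x_test = tf.reshape(x_train, [len(x_train), 28, -1]), tf.reshape(x_test, [len(x_test), 28, -1])
# Build the network
model = tf.keras.models.Sequential([
    tf.keras.layers.LSTM(2),
    tf.keras.layers.Dense(10, activation="sigmoid")
], name="LSTM")
# Define the loss function.
# Note: the Dense layer already applies a sigmoid, so from_logits=True treats those
# sigmoid outputs as logits. The setting is kept as in the original run, because the
# C code in Section 5 reproduces the sigmoid outputs of this model.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Configure the model for training
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
# Train the model; epochs is set to 1 here for convenience
model.fit(x_train, y_train, epochs=1, batch_size=20)
model.summary()  # print the network structure
# Save the model
# tf.keras.models.save_model(model,filepath='F:/script/pyscript/learning/nn-learning/model/')
# Predict on one test sample
predictions = model.predict(x_test[:1])
print("output from tf:", predictions[0])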
In my terminal, one run of the script (with the model saved) produced the following result:
In TensorFlow, printing all model parameters with
print(model.weights)
shows that the LSTM layer has three parameters:
| variable | shape | description |
|---|---|---|
| lstm/lstm_cell/kernel | (28, 8) in this example | weights for cell kernel. |
| lstm/lstm_cell/recurrent_kernel | (2, 8) in this example | weights for cell recurrent kernel. |
| lstm/lstm_cell/bias | (8,) in this example | weights for cell kernel bias and recurrent bias. |
Next we explain the meaning and internal layout of each parameter.
kernel: the connection weights between the current input $x_t$ and the gates inside the LSTM nodes, with shape (feature dimension, number of LSTM nodes * 4). Here the feature dimension is 28 and the LSTM layer has 2 nodes, so kernel has shape (28, 2*4). The columns of kernel are ordered $[i,f,c,o]$, where $i$ holds the input-gate weights, $f$ the forget-gate weights, $o$ the output-gate weights, and $c$ the weights used when updating the cell state. When the LSTM layer has several nodes, the parameters of each gate are stored in that order; in this example the layout is $[W_{i0},W_{i1},W_{f0},W_{f1},W_{c0},W_{c1},W_{o0},W_{o1}]$.
recurrent_kernel: the connection weights between the previous hidden state $h_{t-1}$ and the gates inside the LSTM nodes, with shape (number of LSTM nodes, number of LSTM nodes * 4). Here the LSTM layer has 2 nodes, so recurrent_kernel has shape (2, 2*4). The ordering is the same as for kernel; in terms of the formulas in Section 3 it is $[U_{i0},U_{i1},U_{f0},U_{f1},U_{c0},U_{c1},U_{o0},U_{o1}]$.
bias: the biases of the gates, with shape (number of LSTM nodes * 4,). Here the LSTM layer has 2 nodes, so bias has shape (2*4,), in the same order: $[b_{i0},b_{i1},b_{f0},b_{f1},b_{c0},b_{c1},b_{o0},b_{o1}]$.
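The column layout described above means that the entry for gate number `g` (in the order i, f, c, o) and node `j` sits in column `g * units + j`. A minimal sketch of splitting such a matrix into per-gate blocks follows; the helper `split_gates` and the stand-in matrix are made up for illustration, with plain lists in place of the real weight arrays.

```python
def split_gates(matrix, units):
    """Split a (rows x 4*units) matrix into per-gate blocks in the order i, f, c, o."""
    names = ("i", "f", "c", "o")
    return {name: [row[g * units:(g + 1) * units] for row in matrix]
            for g, name in enumerate(names)}

# Stand-in for recurrent_kernel with shape (2, 8); each entry encodes its column index.
units = 2
recurrent_kernel = [[float(col) for col in range(4 * units)] for _ in range(units)]
U = split_gates(recurrent_kernel, units)
print(U["f"])  # the forget-gate block: columns 2 and 3 of every row
```

The same index arithmetic appears in the C implementation in Section 5, where the column for node `j` of a given gate is written as `j + units_num * g`.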
5. Forward Pass Based on the TF Model Parameters (C Implementation)
The file coef.h stores the model parameters:
#ifndef COEFS_H__
#define COEFS_H__
float lstm_weights[28][8] =
{
{-0.12686221, 0.26724398, 0.37801254, 0.40262774, -0.07338494,
0.32928947, -0.14913937, -0.28241915},
{ 0.33461186, -0.21897568, -0.06927446, 0.4203136 , -0.1555764 ,
0.02720599, -0.0151168 , -0.3075028 },
{-0.20676394, 0.24815308, -0.01202544, 0.18466435, -0.14753465,
0.08536137, 0.36230126, 0.15645675},
{-0.08381846, 0.16599338, -0.17138797, 0.14833589, -0.42522797,
0.21655686, 0.39841136, -0.41007423},
{ 0.18537116, -0.17786646, 0.5218234 , 0.6530849 , 0.11405347,
0.503562 , -0.02169047, -0.13635898},
{-0.14589112, 0.10930037, 0.20093921, 0.04657064, 0.05063201,
0.59452593, 0.2864187 , -0.26895872},
{ 0.28542188, -0.13795187, 0.4809328 , 0.16436823, 0.1529425 ,
0.33623406, -0.5526711 , -0.11713465},
{ 0.03136083, 0.20362344, 0.48479402, 0.33098167, 0.23541297,
0.5498583 , 0.12794426, 0.19516873},
{ 0.06115872, 0.5377143 , 0.68757766, -0.02475011, -0.31547132,
-0.22972348, 0.21856314, 0.2866804 },
{-0.06472135, 0.5002736 , 0.58080274, 0.5080287 , 0.32763457,
0.44843873, -0.02994427, 0.79296046},
{ 0.10341086, 0.6132778 , 0.1726614 , 0.33635128, -0.09166758,
0.31993377, 0.00283198, 0.53131664},
{ 0.08927457, 1.1540079 , 0.34614214, 0.13580003, 0.13172196,
0.04891009, 0.04630767, 0.24348891},
{ 0.22247258, 1.1757419 , -0.13906561, 0.2936401 , -0.33713883,
-0.25261644, -0.03830051, 0.37149888},
{ 0.20530753, 1.034049 , -0.05202543, -0.350264 , -0.39396283,
0.16860706, 0.20226672, 0.16733885},
{ 0.25870234, 0.2164095 , -0.02035809, -0.06610312, -0.3059797 ,
-0.21448267, 0.00927329, -0.435068 },
{ 0.45796677, 0.04458171, -0.3398304 , -0.3472258 , -0.36355677,
0.2214003 , -0.1399886 , -0.7422219 },
{-0.0643818 , -0.51386756, 0.24997349, -0.40613016, -0.3639028 ,
0.22827196, 0.04702655, -0.20871094},
{-0.31210366, -0.70306814, -0.43561232, -0.7216265 , -0.38581905,
0.5686793 , -0.61060965, -0.07557922},
{-0.5796511 , -0.32600936, -0.01578644, -0.33506516, 0.35982788,
0.15735114, -0.517505 , 0.2061426 },
{-0.3442457 , -0.10941461, -0.7360151 , -0.48871848, 0.20767467,
0.16316976, -0.55064833, -0.3612312 },
{ 0.00775149, -0.3448457 , -0.28391284, -0.45531642, 0.6768015 ,
0.36520806, -0.37463167, -0.04423084},
{ 0.4180928 , 0.18397227, -0.42225787, 0.14224927, 0.5412814 ,
-0.11046116, -0.9047872 , 0.00905087},
{ 0.55453503, 0.00336298, -0.21684031, 0.25182927, 0.45167074,
0.51548773, -0.89206326, 0.5669926 },
{ 0.7232391 , 0.8021931 , -0.20340312, 0.26141882, 0.8312971 ,
0.6391402 , -0.47316524, 0.32208198},
{ 0.78522533, 0.9301855 , -0.47462672, 0.01434673, 0.15459308,
0.0231053 , -0.5864275 , 0.5246032 },
{ 0.35548374, 0.272409 , -0.24058025, -0.14529192, 0.44297072,
0.3809 , 0.00734484, 0.3514555 },
{ 0.1273268 , 0.59552336, -0.07034865, -0.24836257, 0.34519887,
0.29520357, -0.23103833, 0.14769706},
{ 0.29127362, 0.22764683, -0.13500527, -0.0964766 , 0.40833357,
-0.18485136, -0.01389134, 0.3904065 }
};
float lstm_reweights[2][8] =
{
{ 0.47104427, -0.8412283 , -0.18447576, -0.1563985 , 0.33188817,
0.45603213, -0.7687683 , -1.0958537 },
{ 0.31553912, -0.43891597, 0.18234205, 0.44790304, 0.49662977,
0.7006842 , 0.01229569, 0.8207902 }
};
float lstm_bias[8] =
{ 0.4637353 , -0.14936855, 1.3707707 , 1.1197791 ,
-0.01564352, -0.05514838, 0.924866 , 0.9420725 };
float dense_weights[2][10] =
{
{ 1.7468047 , -1.1363125 , 1.3102428 , -0.16758148, 0.02137489,
0.01299422, 0.28959563, -0.57493824, -0.72183365, -0.7659956 },
{ 0.77278185, -1.2110348 , 0.7480287 , 0.7589525 , -0.45125306,
0.9120811 , 0.4531371 , -1.0628165 , 0.6895193 , -0.62167704},
};
float dense_bias[10] =
{-0.08420861, -0.17818679, 0.00183412, 0.01860343, 0.25198904,
-0.1387404 , 0.11423122, 0.13001607, -0.09703284, -0.03688832};
#endif
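A header like the one above does not have to be typed by hand. The sketch below shows one possible way to render a 2-D list of numbers as a C array definition; the helper `format_c_array` is hypothetical and not part of the original project. With a Keras model, the nested lists could presumably be obtained from the NumPy arrays returned by `model.get_weights()` via `.tolist()`; here a small made-up matrix is used so the snippet runs without TensorFlow.

```python
def format_c_array(name, rows):
    """Render a 2-D list of numbers as a C `float name[R][C]` definition."""
    lines = ["    {" + ", ".join(f"{v:.8f}" for v in row) + "}" for row in rows]
    body = ",\n".join(lines)
    return f"float {name}[{len(rows)}][{len(rows[0])}] =\n{{\n{body}\n}};\n"

# Made-up 2x2 matrix standing in for a real weight array.
demo = [[0.5, -1.0], [2.0, 0.25]]
print(format_c_array("demo_weights", demo))
```

Eight decimal places match the precision of the values listed in coef.h; 1-D arrays such as the biases would need a similar helper that emits a single brace level.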
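lstm.h:
/**
 * @file lstm.h
 * @brief Implementation of the LSTM forward pass
 * Details
 *
 * @author 算法在世间
 * @email guzyu12@sina.com
 * @date
 * @version 1.0
 */
#ifndef LSTM_H__
#define LSTM_H__
#include <math.h>
// Forward pass of the LSTM recurrent layer.
// x, weights and rec_weights refer to contiguous row-major 2-D float arrays;
// callers cast them to float ** and the implementation indexes them as flat buffers.
void recurrent_forward(float **x, int seq_len, int dim_len, int units_num,
                       float **weights, float **rec_weights, float *bias,
                       float *hidden_stat, float *carry_stat);
// Forward pass of an ordinary fully connected layer
void dense_forward(float *x, int input_len, int units_num, float **weights, float *bias, float *y);
#endif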
lstm.c contains the implementation of the forward-pass functions:
#include <stdlib.h> /* calloc, free */
#include <string.h> /* memcpy */
#include "lstm.h"
// Sigmoid function
static double sigmoid(double x) { return 1.0 / (1.0 + exp(-x)); }
// Forward pass of the LSTM recurrent layer.
// x, weights and rec_weights are 2-D arrays passed through casts; they are
// indexed here as flat row-major buffers.
void recurrent_forward(float **x, int seq_len, int dim_len, int units_num,
                       float **weights, float **rec_weights, float *bias,
                       float *hidden_stat, float *carry_stat)
{
    const float *xf = (const float *)x;
    const float *wf = (const float *)weights;
    const float *rf = (const float *)rec_weights;
    const int cols = units_num * 4; // columns per row, gate order [i, f, c, o]
    float *h_tm1 = (float *)calloc(units_num, sizeof(float));
    float *c_tm1 = (float *)calloc(units_num, sizeof(float));
    if (h_tm1 == NULL || c_tm1 == NULL) {
        // Allocation failed: leave the states untouched
        free(h_tm1);
        free(c_tm1);
        return;
    }
    // Process the sequence one time step at a time
    for (int step = 0; step < seq_len; step++)
    {
        const float *fv = xf + step * dim_len; // feature vector of this step
        // Snapshot h_{t-1} and c_{t-1} so every unit reads the previous states
        memcpy(h_tm1, hidden_stat, sizeof(float) * units_num);
        memcpy(c_tm1, carry_stat, sizeof(float) * units_num);
        for (int j = 0; j < units_num; j++)
        {
            // Weighted sums of the input for each gate
            double zi = 0.0, zf = 0.0, zc = 0.0, zo = 0.0;
            for (int k = 0; k < dim_len; k++) {
                const float *w = wf + k * cols;
                zi += fv[k] * w[j + units_num * 0];
                zf += fv[k] * w[j + units_num * 1];
                zc += fv[k] * w[j + units_num * 2];
                zo += fv[k] * w[j + units_num * 3];
            }
            // Weighted sums of the previous hidden state for each gate
            double ih = 0.0, fh = 0.0, ch = 0.0, oh = 0.0;
            for (int k = 0; k < units_num; k++) {
                const float *u = rf + k * cols;
                ih += h_tm1[k] * u[j + units_num * 0];
                fh += h_tm1[k] * u[j + units_num * 1];
                ch += h_tm1[k] * u[j + units_num * 2];
                oh += h_tm1[k] * u[j + units_num * 3];
            }
            // Pre-activations of i, f, c, o
            zi += (ih + bias[j + units_num * 0]);
            zf += (fh + bias[j + units_num * 1]);
            zc += (ch + bias[j + units_num * 2]);
            zo += (oh + bias[j + units_num * 3]);
            // Gate values and state updates
            double i = sigmoid(zi);
            double f = sigmoid(zf);
            double c = f * c_tm1[j] + i * tanh(zc);
            double o = sigmoid(zo);
            double h = o * tanh(c);
            // Store the new hidden state and cell state
            hidden_stat[j] = (float)h;
            carry_stat[j] = (float)c;
        }
    }
    free(h_tm1);
    free(c_tm1);
}
/**@brief Forward pass of a fully connected layer (sigmoid activation) */
void dense_forward(float *x, int input_len, int units_num, float **weights, float *bias, float *y)
{
    const float *wf = (const float *)weights; // flat row-major view of the weights
    for (int k = 0; k < units_num; k++) {
        double sum = 0.0;
        for (int j = 0; j < input_len; j++) {
            sum += x[j] * wf[j * units_num + k];
        }
        y[k] = (float)sigmoid(sum + bias[k]);
    }
}
main.c
#include <stdio.h>
#include "src/lstm.h"
#include "src/coef.h"
float sample[28][28] = {
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.32941176,0.7254902,0.62352941,0.59215686,0.23529412,0.14117647,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.87058824,0.99607843,0.99607843,0.99607843,0.99607843,0.94509804,0.77647059,0.77647059,0.77647059,0.77647059,0.77647059,0.77647059,0.77647059,0.77647059,0.66666667,0.20392157,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.2627451,0.44705882,0.28235294,0.44705882,0.63921569,0.89019608,0.99607843,0.88235294,0.99607843,0.99607843,0.99607843,0.98039216,0.89803922,0.99607843,0.99607843,0.54901961,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.06666667,0.25882353,0.05490196,0.2627451,0.2627451,0.2627451,0.23137255,0.08235294,0.9254902,0.99607843,0.41568627,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.3254902,0.99215686,0.81960784,0.07058824,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.08627451,0.91372549,1.,0.3254902,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.50588235,0.99607843,0.93333333,0.17254902,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.23137255,0.97647059,0.99607843,0.24313725,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.52156863,0.99607843,0.73333333,0.01960784,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.03529412,0.80392157,0.97254902,0.22745098,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.49411765,0.99607843,0.71372549,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.29411765,0.98431373,0.94117647,0.22352941,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.0745098,0.86666667,0.99607843,0.65098039,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.01176471,0.79607843,0.99607843,0.85882353,0.1372549,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.14901961,0.99607843,0.99607843,0.30196078,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.12156863,0.87843137,0.99607843,0.45098039,0.00392157,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.52156863,0.99607843,0.99607843,0.20392157,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.23921569,0.94901961,0.99607843,0.99607843,0.20392157,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.4745098,0.99607843,0.99607843,0.85882353,0.15686275,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.4745098,0.99607843,0.81176471,0.07058824,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.},
{0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.}
};
int main(void)
{
    int seq_len = 28;   // Sequence length: the 28*28 image is treated as a sequence, so its 28 rows are the time steps
    int dim_len = 28;   // Feature dimension
    int units_num = 2;  // Number of nodes in the recurrent layer
    float hidden_stat[2] = {0.0f};  // Hidden states, one entry per node
    float carry_stat[2] = {0.0f};   // Cell states, same length
    float rslt_array[10] = {0.0f};  // Digit recognition, 10 classes (0 to 9)
    // Call the forward pass of the recurrent layer
    recurrent_forward((float **)sample, seq_len, dim_len, units_num,
                      (float **)lstm_weights, (float **)lstm_reweights, lstm_bias,
                      hidden_stat, carry_stat);
    // Call the forward pass of the fully connected layer
    dense_forward(hidden_stat, units_num, 10, (float **)dense_weights, dense_bias, rslt_array);
    // Print the output
    printf("\n{ ");
    for (int i = 0; i < 10; i++) { printf("%f ", rslt_array[i]); }
    printf("}\n");
    return 0;
}
The final output of the C forward pass is as follows:
6. References
[1] Understanding LSTM Networks
[2] Hung-yi Lee, Deep Learning Tutorial
[3] Hung-yi Lee, Machine Learning video lectures