deeplearning.ai 总结 - C++实现lstm cell
flyfish
// Meaning of each variable
// xt -- your input data at timestep "t", array of shape(n_x, m).
// a_prev -- Hidden state at timestep "t-1", array of shape(n_a, m)
// c_prev -- Memory state at timestep "t-1", array of shape(n_a, m)
// Wf -- Weight matrix of the forget gate, array of shape(n_a, n_a + n_x)
// bf -- Bias of the forget gate, array of shape(n_a, 1)
// Wi -- Weight matrix of the update gate, array of shape(n_a, n_a + n_x)
// bi -- Bias of the update gate, array of shape(n_a, 1)
// Wc -- Weight matrix of the first "tanh", array of shape(n_a, n_a + n_x)
// bc -- Bias of the first "tanh", array of shape(n_a, 1)
// Wo -- Weight matrix of the output gate, array of shape(n_a, n_a + n_x)
// bo -- Bias of the output gate, array of shape(n_a, 1)
// Wy -- Weight matrix relating the hidden - state to the output, array of shape(n_y, n_a)
// by -- Bias relating the hidden - state to the output, array of shape(n_y, 1)
// Several ways of generating random values are shown below; pick any one of them.
// For example, Eigen::MatrixXd Wf = Eigen::MatrixXd::Random(5,3); is one of the more concise forms.
// Dimensions used in this example: n_x = 3, n_a = 5, n_y = 2, m = 10.
Eigen::Matrix<double, 3, 10> xt;
xt.setRandom();
std::cout << "xt1:\n" << xt << std::endl;
Eigen::Matrix<double, 5, 10> a_prev;
a_prev.setRandom();
std::cout << "a_prev:\n" << a_prev << std::endl;
Eigen::Matrix<double, 5, 10> c_prev = Eigen::MatrixXd::Random(5, 10);
Eigen::MatrixXd Wf = Eigen::MatrixXd::Random(5, 5 + 3);
Eigen::VectorXd bf(5); // equivalent to Eigen::Matrix<double, 5, 1> bf;
bf.setRandom();
Eigen::Matrix<double, 5, 5 + 3> Wi = Eigen::MatrixXd::Random(5, 5 + 3);
Eigen::Matrix<double, 5, 1> bi = Eigen::MatrixXd::Random(5, 1);
Eigen::Matrix<double, 5, 5 + 3> Wo = Eigen::MatrixXd::Random(5, 5 + 3);
Eigen::Matrix<double, 5, 1> bo = Eigen::MatrixXd::Random(5, 1);
Eigen::Matrix<double, 5, 5 + 3> Wc = Eigen::MatrixXd::Random(5, 5 + 3);
Eigen::Matrix<double, 5, 1> bc = Eigen::MatrixXd::Random(5, 1);
Eigen::Matrix<double, 2, 5> Wy = Eigen::MatrixXd::Random(2, 5);
Eigen::Matrix<double, 2, 1> by = Eigen::MatrixXd::Random(2, 1);
// Stack a_prev (5 x 10) on top of xt (3 x 10) to form the 8 x 10 gate input.
Eigen::Matrix<double, 3 + 5, 10> concat;
concat << a_prev, xt;
std::cout << "concat a_prev xt:\n" << concat << std::endl;
// Forget gate.
Eigen::MatrixXd ft = sigmond_forward(matrix_add_bias(Wf * concat, bf));
std::cout << "ft:\n" << ft << std::endl;
// Update (input) gate.
Eigen::MatrixXd it = sigmond_forward(matrix_add_bias(Wi * concat, bi));
std::cout << "it:\n" << it << std::endl;
// Candidate memory value (c tilde).
Eigen::MatrixXd cct = (matrix_add_bias(Wc * concat, bc)).array().tanh();
std::cout << "cct:\n" << cct << std::endl;
// New memory state: it * cct + ft * c_prev, all element-wise.
Eigen::MatrixXd c_next = it.cwiseProduct(cct) + ft.cwiseProduct(c_prev);
std::cout << "c_next:\n" << c_next << std::endl;
// cwiseProduct() lets a Matrix be multiplied element-wise directly, without converting to an Array.
// Output gate.
Eigen::MatrixXd ot = sigmond_forward(matrix_add_bias(Wo * concat, bo));
Eigen::MatrixXd t = (c_next.array().tanh());
// New hidden state: ot * tanh(c_next), element-wise.
Eigen::MatrixXd a_next = ot.cwiseProduct(t);
std::cout << "a_next:\n" << a_next << std::endl;
// Prediction: softmax(Wy * a_next + by).
Eigen::MatrixXd yt_pred = softmax_forward(matrix_add_bias(Wy*a_next, by));
// Print the prediction itself (the original printed a_next under this label).
std::cout << "yt_pred:\n" << yt_pred << std::endl;
下面是上面代码中用到的各个函数的实现
// Adds the bias vector b to every column of x and returns the result as a
// new matrix: result(i, j) = x(i, j) + b(i). x itself is left untouched.
Eigen::MatrixXd matrix_add_bias(const Eigen::MatrixXd & x, const Eigen::VectorXd& b)
{
    const Eigen::Index n_rows = x.rows();
    const Eigen::Index n_cols = x.cols();
    Eigen::MatrixXd biased(n_rows, n_cols);
    // Process one column at a time; only the first n_rows entries of b are read.
    for (Eigen::Index col = 0; col < n_cols; ++col)
    {
        biased.col(col) = x.col(col) + b.head(n_rows);
    }
    return biased;
}
// Applies the logistic sigmoid 1 / (1 + exp(-v)) to every coefficient of x
// and returns the result as a new matrix of the same size.
Eigen::MatrixXd sigmond_forward(const Eigen::MatrixXd &x)
{
    // Scalar kernel; uses std::exp so each coefficient is evaluated exactly
    // as in a plain element-by-element loop.
    const auto logistic = [](double v) { return 1.0 / (1.0 + std::exp(-v)); };
    Eigen::MatrixXd activated = x.unaryExpr(logistic);
    return activated;
}
// Column-wise softmax: every column of x is treated as one sample and is
// normalised so that its entries are positive and sum to 1.
//
// Each column is shifted by its OWN maximum before exponentiation. Softmax is
// invariant to a per-column shift, so the result is mathematically unchanged,
// but the largest exponent in every column becomes exp(0) = 1. This keeps the
// column sum >= 1 and avoids the 0/0 = NaN that a single matrix-wide maximum
// produces when one column lies far below the global maximum.
//
// An empty input (zero rows or zero columns) yields an empty matrix of the
// same shape instead of calling maxCoeff() on nothing.
Eigen::MatrixXd softmax_forward(const Eigen::MatrixXd &x)
{
    const auto rows = x.rows();
    const auto cols = x.cols();
    Eigen::MatrixXd res(rows, cols);
    if (rows == 0 || cols == 0)
    {
        return res;
    }
    for (auto j = cols * 0; j < cols; ++j)
    {
        // Shift for numerical stability, then exponentiate.
        const double col_max = x.col(j).maxCoeff();
        res.col(j) = (x.col(j).array() - col_max).exp();
        // Normalise by the column sum (the softmax formula).
        res.col(j) /= res.col(j).sum();
    }
    return res;
}
为了验证C++代码的正确性,用Python实现同样的计算并输出全部中间数据,然后查看两边的数据是否一致
Python代码主要来自deeplearning.ai的课程代码
import numpy as np
np.random.seed(1);
def softmax(x):
    """Column-wise softmax.

    Each column of ``x`` (axis 0) is normalised so its entries are positive
    and sum to 1.

    The maximum is taken per column rather than over the whole array. Softmax
    is invariant to a per-column shift, so results are unchanged, but every
    column then contains exp(0) = 1 and its sum cannot underflow to zero
    (which would give 0/0 = NaN for a column far below the global maximum).
    """
    x = np.asarray(x, dtype=float)
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / e_x.sum(axis=0)
def sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x))."""
    return 1 / (1 + np.exp(-x))
def lstm_cell_forward(xt, a_prev, c_prev, parameters):
    """
    Implement a single forward step of the LSTM-cell as described in Figure (4)

    Every intermediate value (ft, it, cct, c_next, ot, a_next, yt_pred) is
    printed so it can be compared against the C++ implementation.

    Arguments:
    xt -- your input data at timestep "t", numpy array of shape (n_x, m).
    a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
    c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
    parameters -- python dictionary containing:
                        Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
                        bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
                        Wi -- Weight matrix of the update gate, numpy array of shape (n_a, n_a + n_x)
                        bi -- Bias of the update gate, numpy array of shape (n_a, 1)
                        Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
                        bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
                        Wo -- Weight matrix of the output gate, numpy array of shape (n_a, n_a + n_x)
                        bo -- Bias of the output gate, numpy array of shape (n_a, 1)
                        Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
                        by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)

    Returns:
    a_next -- next hidden state, of shape (n_a, m)
    c_next -- next memory state, of shape (n_a, m)
    yt_pred -- prediction at timestep "t"; as written below it is softmax(a_next),
               so its shape is (n_a, m)
    cache -- tuple of values needed for the backward pass, contains
             (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilde),
          c stands for the memory value
    """
    # Retrieve parameters from "parameters"
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]

    # Retrieve dimensions from shapes of xt and Wy
    n_x, m = xt.shape
    n_y, n_a = Wy.shape

    ### START CODE HERE ###
    # Concatenate a_prev and xt (≈3 lines)
    # Rows [0, n_a) hold a_prev, rows [n_a, n_a + n_x) hold xt.
    concat = np.zeros([n_x + n_a, m])
    concat[:n_a, :] = a_prev
    concat[n_a:, :] = xt

    # Compute values for ft, it, cct, c_next, ot, a_next using the formulas given figure (4) (≈6 lines)
    ft = sigmoid(np.dot(Wf, concat) + bf)
    print("ft:\n",ft)
    it = sigmoid(np.dot(Wi, concat) + bi)
    print("it:\n",it)
    cct = np.tanh(np.dot(Wc, concat) + bc)
    print("cct:\n",cct)
    c_next = ft * c_prev + it * cct
    print("c_next:\n",c_next)
    ot = sigmoid(np.dot(Wo, concat) + bo)
    print("ot:\n",ot)
    a_next = ot * np.tanh(c_next)
    print("a_next:\n",a_next)

    # Compute prediction of the LSTM cell (≈1 line)
    # NOTE(review): Wy and by are not applied here, so the prediction is a
    # softmax over the hidden state itself rather than softmax(Wy @ a_next + by).
    # Kept as-is because the reference output this code is compared against
    # was produced with this form -- confirm before changing.
    yt_pred = softmax(a_next)
    print("yt_pred:\n",yt_pred)
    ### END CODE HERE ###

    # store values needed for backward propagation in cache
    cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)

    return a_next, c_next, yt_pred, cache
# Driver: build fixed inputs and parameters, run one LSTM cell step and print
# everything so the numbers can be pasted into the C++ version.
# n_x = 3, n_a = 5, n_y = 2, m = 10.

#xt = np.random.randn(3,10)
# xt is hard-coded instead of drawn, so the random draws below start with a_prev.
xt = np.array([[ 1.62434536, -0.61175641, -0.52817175, -1.07296862,  0.86540763, -2.3015387,   1.74481176, -0.7612069,   0.3190391,  -0.24937038],
               [ 1.46210794, -2.06014071, -0.3224172,  -0.38405435,  1.13376944, -1.09989127, -0.17242821, -0.87785842,  0.04221375,  0.58281521],
               [-1.10061918,  1.14472371,  0.90159072,  0.50249434,  0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808,  0.53035547]])
print("xt:\n",xt)
a_prev = np.random.randn(5,10)
print("a_prev:\n",a_prev)
c_prev = np.random.randn(5,10)
print("c_prev:\n",c_prev)
Wf = np.random.randn(5, 5+3)
print("Wf:\n",Wf)
bf = np.random.randn(5,1)
print("bf:\n",bf)
Wi = np.random.randn(5, 5+3)
print("Wi:\n",Wi)
bi = np.random.randn(5,1)
print("bi:\n",bi)
Wo = np.random.randn(5, 5+3)
print("Wo:\n",Wo)
bo = np.random.randn(5,1)
print("bo:\n",bo)
Wc = np.random.randn(5, 5+3)
print("Wc:\n",Wc)
bc = np.random.randn(5,1)
print("bc:\n",bc)
Wy = np.random.randn(2,5)
print("Wy:\n",Wy)
by = np.random.randn(2,1)
print("by:\n",by)

parameters = {"Wf": Wf, "Wi": Wi, "Wo": Wo, "Wc": Wc, "Wy": Wy, "bf": bf, "bi": bi, "bo": bo, "bc": bc, "by": by}

a_next, c_next, yt, cache = lstm_cell_forward(xt, a_prev, c_prev, parameters)
print("a_next[4] = ", a_next[4])
# Report the shape of a_next itself (the original printed c_next.shape here).
print("a_next.shape = ", a_next.shape)
print("c_next[2] = ", c_next[2])
print("c_next.shape = ", c_next.shape)
print("yt[1] =", yt[1])
print("yt.shape = ", yt.shape)
print("cache[1][3] =", cache[1][3])
print("len(cache) = ", len(cache))
输出数据是
xt:
[[ 1.62434536 -0.61175641 -0.52817175 -1.07296862 0.86540763 -2.3015387
1.74481176 -0.7612069 0.3190391 -0.24937038]
[ 1.46210794 -2.06014071 -0.3224172 -0.38405435 1.13376944 -1.09989127
-0.17242821 -0.87785842 0.04221375 0.58281521]
[-1.10061918 1.14472371 0.90159072 0.50249434 0.90085595 -0.68372786
-0.12289023 -0.93576943 -0.26788808 0.53035547]]
a_prev:
[[ 1.62434536 -0.61175641 -0.52817175 -1.07296862 0.86540763 -2.3015387
1.74481176 -0.7612069 0.3190391 -0.24937038]
[ 1.46210794 -2.06014071 -0.3224172 -0.38405435 1.13376944 -1.09989127
-0.17242821 -0.87785842 0.04221375 0.58281521]
[-1.10061918 1.14472371 0.90159072 0.50249434 0.90085595 -0.68372786
-0.12289023 -0.93576943 -0.26788808 0.53035547]
[-0.69166075 -0.39675353 -0.6871727 -0.84520564 -0.67124613 -0.0126646
-1.11731035 0.2344157 1.65980218 0.74204416]
[-0.19183555 -0.88762896 -0.74715829 1.6924546 0.05080775 -0.63699565
0.19091548 2.10025514 0.12015895 0.61720311]]
c_prev:
[[ 0.30017032 -0.35224985 -1.1425182 -0.34934272 -0.20889423 0.58662319
0.83898341 0.93110208 0.28558733 0.88514116]
[-0.75439794 1.25286816 0.51292982 -0.29809284 0.48851815 -0.07557171
1.13162939 1.51981682 2.18557541 -1.39649634]
[-1.44411381 -0.50446586 0.16003707 0.87616892 0.31563495 -2.02220122
-0.30620401 0.82797464 0.23009474 0.76201118]
[-0.22232814 -0.20075807 0.18656139 0.41005165 0.19829972 0.11900865
-0.67066229 0.37756379 0.12182127 1.12948391]
[ 1.19891788 0.18515642 -0.37528495 -0.63873041 0.42349435 0.07734007
-0.34385368 0.04359686 -0.62000084 0.69803203]]
Wf:
[[-0.44712856 1.2245077 0.40349164 0.59357852 -1.09491185 0.16938243
0.74055645 -0.9537006 ]
[-0.26621851 0.03261455 -1.37311732 0.31515939 0.84616065 -0.85951594
0.35054598 -1.31228341]
[-0.03869551 -1.61577235 1.12141771 0.40890054 -0.02461696 -0.77516162
1.27375593 1.96710175]
[-1.85798186 1.23616403 1.62765075 0.3380117 -1.19926803 0.86334532
-0.1809203 -0.60392063]
[-1.23005814 0.5505375 0.79280687 -0.62353073 0.52057634 -1.14434139
0.80186103 0.0465673 ]]
bf:
[[-0.18656977]
[-0.10174587]
[ 0.86888616]
[ 0.75041164]
[ 0.52946532]]
Wi:
[[ 0.13770121 0.07782113 0.61838026 0.23249456 0.68255141 -0.31011677
-2.43483776 1.0388246 ]
[ 2.18697965 0.44136444 -0.10015523 -0.13644474 -0.11905419 0.01740941
-1.12201873 -0.51709446]
[-0.99702683 0.24879916 -0.29664115 0.49521132 -0.17470316 0.98633519
0.2135339 2.19069973]
[-1.89636092 -0.64691669 0.90148689 2.52832571 -0.24863478 0.04366899
-0.22631424 1.33145711]
[-0.28730786 0.68006984 -0.3198016 -1.27255876 0.31354772 0.50318481
1.29322588 -0.11044703]]
bi:
[[-0.61736206]
[ 0.5627611 ]
[ 0.24073709]
[ 0.28066508]
[-0.0731127 ]]
Wo:
[[ 1.16033857 0.36949272 1.90465871 1.1110567 0.6590498 -1.62743834
0.60231928 0.4202822 ]
[ 0.81095167 1.04444209 -0.40087819 0.82400562 -0.56230543 1.95487808
-1.33195167 -1.76068856]
[-1.65072127 -0.89055558 -1.1191154 1.9560789 -0.3264995 -1.34267579
1.11438298 -0.58652394]
[-1.23685338 0.87583893 0.62336218 -0.43495668 1.40754 0.12910158
1.6169496 0.50274088]
[ 1.55880554 0.1094027 -1.2197444 2.44936865 -0.54577417 -0.19883786
-0.7003985 -0.20339445]]
bo:
[[ 0.24266944]
[ 0.20183018]
[ 0.66102029]
[ 1.79215821]
[-0.12046457]]
Wc:
[[-1.23312074e+00 -1.18231813e+00 -6.65754518e-01 -1.67419581e+00
8.25029824e-01 -4.98213564e-01 -3.10984978e-01 -1.89148284e-03]
[-1.39662042e+00 -8.61316361e-01 6.74711526e-01 6.18539131e-01
-4.43171931e-01 1.81053491e+00 -1.30572692e+00 -3.44987210e-01]
[-2.30839743e-01 -2.79308500e+00 1.93752881e+00 3.66332015e-01
-1.04458938e+00 2.05117344e+00 5.85662000e-01 4.29526140e-01]
[-6.06998398e-01 1.06222724e-01 -1.52568032e+00 7.95026094e-01
-3.74438319e-01 1.34048197e-01 1.20205486e+00 2.84748111e-01]
[ 2.62467445e-01 2.76499305e-01 -7.33271604e-01 8.36004719e-01
1.54335911e+00 7.58805660e-01 8.84908814e-01 -8.77281519e-01]]
bc:
[[-0.86778722]
[-1.44087602]
[ 1.23225307]
[-0.25417987]
[ 1.39984394]]
Wy:
[[-0.78191168 -0.43750898 0.09542509 0.92145007 0.0607502 ]
[ 0.21112476 0.01652757 0.17718772 -1.11647002 0.0809271 ]]
by:
[[-0.18657899]
[-0.05682448]]
ft:
[[0.93342112 0.01873532 0.31879297 0.03645608 0.70083348 0.34466958
0.14007722 0.03403372 0.69185905 0.62265514]
[0.76944697 0.01548017 0.05212295 0.6846647 0.0380791 0.96645775
0.11998371 0.99169665 0.71083407 0.43393576]
[0.0096553 0.99579174 0.9807537 0.95426937 0.90749184 0.73438514
0.20861643 0.2578163 0.60901046 0.9426081 ]
[0.38869392 0.78367665 0.92568891 0.40404883 0.84818819 0.8742035
0.13017358 0.05644821 0.65028999 0.84117879]
[0.12537878 0.48551632 0.8950891 0.98772175 0.77013915 0.97363848
0.04094597 0.78774353 0.20290836 0.90145966]]
it:
[[1.56339072e-03 9.96099904e-01 7.42313578e-01 9.05874735e-01
1.12309515e-01 6.90213408e-01 3.00505508e-01 8.20579645e-01
3.21278251e-01 3.83345996e-01]
[9.81012966e-01 5.17414473e-01 3.19542854e-01 1.25933181e-01
7.73509903e-01 3.68808714e-02 9.91267605e-01 4.44691623e-01
7.61676248e-01 2.91990904e-01]
[1.83278673e-01 8.05585226e-01 8.31349370e-01 5.77414603e-01
8.92894351e-01 1.92603857e-01 3.36918546e-01 1.01540368e-01
6.34230205e-01 8.55328961e-01]
[2.84591243e-04 9.93218262e-01 8.81040869e-01 7.62818965e-01
1.18062012e-01 9.83774288e-01 2.60383249e-03 6.07897288e-01
9.61623854e-01 9.58539298e-01]
[9.88527761e-01 1.05985672e-02 3.60990517e-01 5.81138043e-01
9.44460008e-01 6.72442596e-02 8.17257718e-01 2.30253916e-01
1.32365782e-01 5.11039172e-01]]
cct:
[[-0.99948169 0.98478327 0.43025665 0.99937994 -0.99815922 0.99983835
-0.90887136 0.99884035 -0.99930754 -0.98346529]
[-0.99981448 0.9972597 -0.63957201 -0.97374353 -0.99872911 0.03899929
-0.8214443 -0.73883911 -0.47400845 -0.9851213 ]
[-0.93498283 0.99999984 0.99863728 -0.64181366 0.97259978 -0.90278087
0.99916331 -0.98760363 0.91674587 0.35773607]
[ 0.94449743 -0.99940948 -0.9479902 -0.97468636 -0.69147317 0.44003728
-0.97115684 -0.44742202 0.85028554 0.31577646]
[ 0.99997551 -0.99993444 -0.99089385 0.7416703 0.93112785 -0.9716276
0.98763731 0.99977554 0.99895712 0.98765691]]
c_next:
[[ 0.27862274 0.974343 -0.04484141 0.89257737 -0.25850285 0.892293
-0.15559838 0.85131693 -0.1234696 0.17413022]
[-1.56130017 0.53539121 -0.17763525 -0.32672026 -0.75392452 -0.07159854
-0.67849403 1.17864168 1.19254048 -0.89363616]
[-0.18530577 0.30324215 0.98717342 0.46550858 1.15486499 -1.6589536
0.27275746 0.11318372 0.72155802 1.02425993]
[-0.0861488 -1.14996116 -0.6625203 -0.57782835 0.08655877 0.53693514
-0.08983124 -0.25067383 0.89687402 1.25278205]
[ 1.13882241 0.07929859 -0.69361675 -0.19987509 1.2055626 0.00996489
0.79307479 0.26454537 0.00642439 1.13397909]]
ot:
[[0.07249733 0.54040145 0.73424974 0.84751397 0.91071347 0.14496133
0.10515789 0.31305591 0.80770578 0.96738433]
[0.99795666 0.03916275 0.0369245 0.00454259 0.43548428 0.01760292
0.98680806 0.35063646 0.95032392 0.17690516]
[0.03601709 0.39948237 0.38506047 0.67347204 0.01054133 0.99982554
0.00130693 0.98388906 0.96706885 0.86799202]
[0.91900679 0.07919111 0.86248002 0.99522256 0.99355306 0.48471955
0.51145512 0.90443733 0.66585323 0.99066356]
[0.7687726 0.13571441 0.03891578 0.00629472 0.07126004 0.21235694
0.42270364 0.53065266 0.99071107 0.48047872]]
a_next:
[[ 1.96924439e-02 4.05628887e-01 -3.29027460e-02 6.03993044e-01
-2.30314578e-01 1.03288489e-01 -1.62316145e-02 2.16558564e-01
-9.92234125e-02 1.66768691e-01]
[-9.13759893e-01 1.91698281e-02 -6.49096391e-03 -1.43350878e-03
-2.77614447e-01 -1.25819395e-03 -5.82749126e-01 2.89984356e-01
7.90066150e-01 -1.26166120e-01]
[-6.59881685e-03 1.17558399e-01 2.91165199e-01 2.92666167e-01
8.63713827e-03 -9.29913841e-01 3.47891416e-04 1.10887123e-01
5.97526168e-01 6.69739123e-01]
[-7.89760541e-02 -6.47578349e-02 -5.00271371e-01 -5.18595781e-01
8.57865859e-02 2.37834666e-01 -4.58214583e-02 -2.22086367e-01
4.75933484e-01 8.41134712e-01]
[ 6.25794109e-01 1.07394604e-02 -2.33611599e-02 -1.24166589e-03
5.95266100e-02 2.11604307e-03 2.79046697e-01 1.37196019e-01
6.36462524e-03 3.90329719e-01]]
yt_pred:
[[0.19582658 0.2682632 0.19818694 0.31749071 0.16858565 0.2323208
0.20414163 0.22004507 0.12005679 0.15121714]
[0.07699782 0.18227372 0.20349115 0.17329962 0.16079722 0.20925895
0.11585003 0.23681001 0.29214623 0.11281869]
[0.19074515 0.20111926 0.27404125 0.23255382 0.21408982 0.08267496
0.20755441 0.19797899 0.24098013 0.25005659]
[0.17742731 0.16760028 0.12419366 0.10332297 0.23126058 0.26577905
0.19818961 0.1419092 0.21339007 0.29680722]
[0.35900314 0.18074354 0.20008701 0.17333287 0.22526672 0.20996623
0.27426432 0.20325672 0.13342679 0.18910036]]
再把这些数据填到C++代码中,然后看输出
// Same LSTM cell step as the random-input demo, but with every input and
// parameter filled from the values printed by the Python run, so the two
// outputs can be compared. n_x = 3, n_a = 5, n_y = 2, m = 10.
Eigen::Matrix<double, 3, 10> xt;
xt << 1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069, 0.3190391, -0.24937038,
    1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842, 0.04221375, 0.58281521,
    -1.10061918, 1.14472371, 0.90159072, 0.50249434, 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547;
Eigen::Matrix<double, 5, 10> a_prev;
a_prev << 1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069, 0.3190391, -0.24937038,
    1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842, 0.04221375, 0.58281521,
    -1.10061918, 1.14472371, 0.90159072, 0.50249434, 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547,
    -0.69166075, -0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646, -1.11731035, 0.2344157, 1.65980218, 0.74204416,
    -0.19183555, -0.88762896, -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514, 0.12015895, 0.61720311;
Eigen::Matrix<double, 5, 10> c_prev;
c_prev << 0.30017032, -0.35224985, -1.1425182, -0.34934272, -0.20889423, 0.58662319, 0.83898341, 0.93110208, 0.28558733, 0.88514116,
    -0.75439794, 1.25286816, 0.51292982, -0.29809284, 0.48851815, -0.07557171, 1.13162939, 1.51981682, 2.18557541, -1.39649634,
    -1.44411381, -0.50446586, 0.16003707, 0.87616892, 0.31563495, -2.02220122, -0.30620401, 0.82797464, 0.23009474, 0.76201118,
    -0.22232814, -0.20075807, 0.18656139, 0.41005165, 0.19829972, 0.11900865, -0.67066229, 0.37756379, 0.12182127, 1.12948391,
    1.19891788, 0.18515642, -0.37528495, -0.63873041, 0.42349435, 0.07734007, -0.34385368, 0.04359686, -0.62000084, 0.69803203;
Eigen::Matrix<double, 5, 5 + 3> Wf;
Wf << -0.44712856, 1.2245077, 0.40349164, 0.59357852, -1.09491185, 0.16938243, 0.74055645, -0.9537006,
    -0.26621851, 0.03261455, -1.37311732, 0.31515939, 0.84616065, -0.85951594, 0.35054598, -1.31228341,
    -0.03869551, -1.61577235, 1.12141771, 0.40890054, -0.02461696, -0.77516162, 1.27375593, 1.96710175,
    -1.85798186, 1.23616403, 1.62765075, 0.3380117, -1.19926803, 0.86334532, -0.1809203, -0.60392063,
    -1.23005814, 0.5505375, 0.79280687, -0.62353073, 0.52057634, -1.14434139, 0.80186103, 0.0465673;
Eigen::VectorXd bf(5); // equivalent to Eigen::Matrix<double, 5, 1> bf;
bf << -0.18656977,
    -0.10174587,
    0.86888616,
    0.75041164,
    0.52946532;
Eigen::Matrix<double, 5, 5 + 3> Wi;
Wi << 0.13770121, 0.07782113, 0.61838026, 0.23249456, 0.68255141, -0.31011677, -2.43483776, 1.0388246,
    2.18697965, 0.44136444, -0.10015523, -0.13644474, -0.11905419, 0.01740941, -1.12201873, -0.51709446,
    -0.99702683, 0.24879916, -0.29664115, 0.49521132, -0.17470316, 0.98633519, 0.2135339, 2.19069973,
    -1.89636092, -0.64691669, 0.90148689, 2.52832571, -0.24863478, 0.04366899, -0.22631424, 1.33145711,
    -0.28730786, 0.68006984, -0.3198016, -1.27255876, 0.31354772, 0.50318481, 1.29322588, -0.11044703;
Eigen::Matrix<double, 5, 1> bi;
bi << -0.61736206,
    0.5627611,
    0.24073709,
    0.28066508,
    -0.0731127;
Eigen::Matrix<double, 5, 5 + 3> Wo;
Wo << 1.16033857, 0.36949272, 1.90465871, 1.1110567, 0.6590498, -1.62743834, 0.60231928, 0.4202822,
    0.81095167, 1.04444209, -0.40087819, 0.82400562, -0.56230543, 1.95487808, -1.33195167, -1.76068856,
    -1.65072127, -0.89055558, -1.1191154, 1.9560789, -0.3264995, -1.34267579, 1.11438298, -0.58652394,
    -1.23685338, 0.87583893, 0.62336218, -0.43495668, 1.40754, 0.12910158, 1.6169496, 0.50274088,
    1.55880554, 0.1094027, -1.2197444, 2.44936865, -0.54577417, -0.19883786, -0.7003985, -0.20339445;
Eigen::Matrix<double, 5, 1> bo;
bo << 0.24266944,
    0.20183018,
    0.66102029,
    1.79215821,
    -0.12046457;
Eigen::Matrix<double, 5, 5 + 3> Wc;
Wc << -1.23312074e+00, -1.18231813e+00, -6.65754518e-01, -1.67419581e+00, 8.25029824e-01, -4.98213564e-01, -3.10984978e-01, -1.89148284e-03,
    -1.39662042e+00, -8.61316361e-01, 6.74711526e-01, 6.18539131e-01, -4.43171931e-01, 1.81053491e+00, -1.30572692e+00, -3.44987210e-01,
    -2.30839743e-01, -2.79308500e+00, 1.93752881e+00, 3.66332015e-01, -1.04458938e+00, 2.05117344e+00, 5.85662000e-01, 4.29526140e-01,
    -6.06998398e-01, 1.06222724e-01, -1.52568032e+00, 7.95026094e-01, -3.74438319e-01, 1.34048197e-01, 1.20205486e+00, 2.84748111e-01,
    2.62467445e-01, 2.76499305e-01, -7.33271604e-01, 8.36004719e-01, 1.54335911e+00, 7.58805660e-01, 8.84908814e-01, -8.77281519e-01;
Eigen::Matrix<double, 5, 1> bc;
bc << -0.86778722,
    -1.44087602,
    1.23225307,
    -0.25417987,
    1.39984394;
Eigen::Matrix<double, 2, 5> Wy;
Wy << -0.78191168, -0.43750898, 0.09542509, 0.92145007, 0.0607502,
    0.21112476, 0.01652757, 0.17718772, -1.11647002, 0.0809271;
// Filled directly below; the original first initialised this with Random(),
// which was immediately overwritten.
Eigen::Matrix<double, 2, 1> by;
by << -0.18657899,
    -0.05682448;
// Stack a_prev (5 x 10) on top of xt (3 x 10) to form the 8 x 10 gate input.
Eigen::Matrix<double, 3 + 5, 10> concat;
concat << a_prev, xt;
std::cout << "concat a_prev xt:\n" << concat << std::endl;
// Forget gate.
Eigen::MatrixXd ft = sigmond_forward(matrix_add_bias(Wf * concat, bf));
std::cout << "ft:\n" << ft << std::endl;
// Update (input) gate.
Eigen::MatrixXd it = sigmond_forward(matrix_add_bias(Wi * concat, bi));
std::cout << "it:\n" << it << std::endl;
// Candidate memory value (c tilde).
Eigen::MatrixXd cct = (matrix_add_bias(Wc * concat, bc)).array().tanh();
std::cout << "cct:\n" << cct << std::endl;
// New memory state: it * cct + ft * c_prev, all element-wise.
Eigen::MatrixXd c_next = it.cwiseProduct(cct) + ft.cwiseProduct(c_prev);
std::cout << "c_next:\n" << c_next << std::endl;
// Output gate.
Eigen::MatrixXd ot = sigmond_forward(matrix_add_bias(Wo * concat, bo));
std::cout << "ot:\n" << ot << std::endl;
Eigen::MatrixXd t = (c_next.array().tanh());
// New hidden state: ot * tanh(c_next), element-wise.
Eigen::MatrixXd a_next = ot.cwiseProduct(t);
std::cout << "a_next:\n" << a_next << std::endl;
// NOTE(review): the softmax is applied to a_next directly and Wy / by are not
// used, mirroring the Python reference (yt_pred = softmax(a_next)) so that the
// printed numbers can be compared one-to-one.
Eigen::MatrixXd yt_pred = softmax_forward(a_next);
std::cout << "yt_pred:\n" << yt_pred << std::endl;
结果:C++代码与Python代码的输出是一致的