RNN. LSTM matlab

最近在学习RNN和LSTM,

(1): http://magicly.me/2017/03/09/iamtrask-anyone-can-code-lstm/

(2):  https://zybuluo.com/hanbingtao/note/581764

(3): http://blog.sina.com.cn/s/blog_a5fdbf010102w7y8.html

在资料1中给出了RNN的python代码,广为流传,并且被(3)翻译成了matlab代码。网址(2)是个很好的理论推导的网站,强烈推荐。想学习的可以缘木求鱼,根据上述资料进行学习。

但是在(1)中提及到的python源码作者说过段时间在twitter上更新LSTM代码,目前还未更新;(3)的作者想根据RNN写LSTM,但是发现并未运行成功,于是我在(2)的基础上对其进行修改,并运行成功,下面是指出作者(2)中的错误:

1:作者在求H_t_diff时有问题,不应该乘以那么导数,因为从后面传过来的误差是输出层(不是输出门)的误差乘以权值矩阵就可以了。

2:求各个门的误差都不用成乘激活函数

3:各个门中均未加偏置


如果你运行成功,还望点个赞。哈哈!


以上错误可以根据我下面给出的源码和(3)中的源码进行比较。

 

下面是我修改后的源代码:

[plain]  view plain  copy
  1. %接下来就是LSTM的Matlab代码,我也进行了注释,用英文注释的,也比较容易懂:  
  2. % implementation of LSTM  
  3. clc  
  4. % clear  
  5. close all  
  6.   
  7.   
  8. %% training dataset generation  
  9. binary_dim     = 8;  
  10.   
  11.   
  12. largest_number = 2^binary_dim - 1;  
  13. binary         = cell(largest_number, 1);  
  14.   
  15.   
  16. for i = 1:largest_number + 1  
  17.     binary{i}      = dec2bin(i-1, binary_dim);  
  18.     int2binary{i}  = binary{i};  
  19. end  
  20.   
  21.   
  22. %% input variables  
  23. alpha      = 0.1;  
  24. input_dim  = 2;  
  25. hidden_dim = 32;  
  26. output_dim = 1;  
  27. allErr = [];  
  28. %% initialize neural network weights  
  29. % in_gate     = sigmoid(X(t) * X_i + H(t-1) * H_i)    ------- (1)  
  30. X_i = 2 * rand(input_dim, hidden_dim) - 1;  
  31. H_i = 2 * rand(hidden_dim, hidden_dim) - 1;  
  32. X_i_update = zeros(size(X_i));  
  33. H_i_update = zeros(size(H_i));  
  34. bi = 2*rand(1,1) - 1;  
  35. bi_update = 0;  
  36.   
  37.   
  38. % forget_gate = sigmoid(X(t) * X_f + H(t-1) * H_f)    ------- (2)  
  39. X_f = 2 * rand(input_dim, hidden_dim) - 1;  
  40. H_f = 2 * rand(hidden_dim, hidden_dim) - 1;  
  41. X_f_update = zeros(size(X_f));  
  42. H_f_update = zeros(size(H_f));  
  43. bf = 2*rand(1,1) - 1;  
  44. bf_update = 0;  
  45. % out_gate    = sigmoid(X(t) * X_o + H(t-1) * H_o)    ------- (3)  
  46. X_o = 2 * rand(input_dim, hidden_dim) - 1;  
  47. H_o = 2 * rand(hidden_dim, hidden_dim) - 1;  
  48. X_o_update = zeros(size(X_o));  
  49. H_o_update = zeros(size(H_o));  
  50. bo = 2*rand(1,1) - 1;  
  51. bo_update = 0;  
  52. % g_gate      = tanh(X(t) * X_g + H(t-1) * H_g)       ------- (4)  
  53. X_g = 2 * rand(input_dim, hidden_dim) - 1;  
  54. H_g = 2 * rand(hidden_dim, hidden_dim) - 1;  
  55. X_g_update = zeros(size(X_g));  
  56. H_g_update = zeros(size(H_g));  
  57. bg = 2*rand(1,1) - 1;  
  58. bg_update = 0;  
  59.   
  60.   
  61. out_para = 2 * rand(hidden_dim, output_dim) - 1;  
  62. out_para_update = zeros(size(out_para));  
  63. % C(t) = C(t-1) .* forget_gate + g_gate .* in_gate    ------- (5)  
  64. % S(t) = tanh(C(t)) .* out_gate                       ------- (6)  
  65. % Out  = sigmoid(S(t) * out_para)                     ------- (7)  
  66. % Note: Equations (1)-(6) are cores of LSTM in forward, and equation (7) is  
  67. % used to transfer hiddent layer to predicted output, i.e., the output layer.  
  68. % (Sometimes you can use softmax for equation (7))  
  69.   
  70.   
  71. %% train   
  72. iter = 99999; % training iterations  
  73. for j = 1:iter  
  74.     % generate a simple addition problem (a + b = c)  
  75.     a_int = randi(round(largest_number/2));   % int version  
  76.     a     = int2binary{a_int+1};              % binary encoding  
  77.       
  78.     b_int = randi(floor(largest_number/2));   % int version  
  79.     b     = int2binary{b_int+1};              % binary encoding  
  80.       
  81.     % true answer  
  82.     c_int = a_int + b_int;                    % int version  
  83.     c     = int2binary{c_int+1};              % binary encoding  
  84.       
  85.     % where we'll store our best guess (binary encoded)  
  86.     d     = zeros(size(c));  
  87.     if length(d)<8  
  88.         pause;  
  89.     end  
  90.       
  91.     % total error  
  92.     overallError = 0;  
  93.       
  94.     % difference in output layer, i.e., (target - out)  
  95.     output_deltas = [];  
  96.       
  97.     % values of hidden layer, i.e., S(t)  
  98.     hidden_layer_values = [];  
  99.     cell_gate_values    = [];  
  100.     % initialize S(0) as a zero-vector  
  101.     hidden_layer_values = [hidden_layer_values; zeros(1, hidden_dim)];  
  102.     cell_gate_values    = [cell_gate_values; zeros(1, hidden_dim)];  
  103.       
  104.     % initialize memory gate  
  105.     % hidden layer  
  106.     H = [];  
  107.     H = [H; zeros(1, hidden_dim)];  
  108.     % cell gate  
  109.     C = [];  
  110.     C = [C; zeros(1, hidden_dim)];  
  111.     % in gate  
  112.     I = [];  
  113.     % forget gate  
  114.     F = [];  
  115.     % out gate  
  116.     O = [];  
  117.     % g gate  
  118.     G = [];  
  119.       
  120.     % start to process a sequence, i.e., a forward pass  
  121.     % Note: the output of a LSTM cell is the hidden_layer, and you need to   
  122.     % transfer it to predicted output  
  123.     for position = 0:binary_dim-1  
  124.         % X ------> input, size: 1 x input_dim  
  125.         X = [a(binary_dim - position)-'0' b(binary_dim - position)-'0'];  
  126.           
  127.         % y ------> label, size: 1 x output_dim  
  128.         y = [c(binary_dim - position)-'0']';  
  129.           
  130.         % use equations (1)-(7) in a forward pass. here we do not use bias  
  131.         in_gate     = sigmoid(X * X_i + H(end, :) * H_i + bi);  % equation (1)  
  132.         forget_gate = sigmoid(X * X_f + H(end, :) * H_f + bf);  % equation (2)  
  133.         out_gate    = sigmoid(X * X_o + H(end, :) * H_o + bo);  % equation (3)  
  134.         g_gate      = tan_h(X * X_g + H(end, :) * H_g + bg);    % equation (4)  
  135.         C_t         = C(end, :) .* forget_gate + g_gate .* in_gate;    % equation (5)  
  136.         H_t         = tan_h(C_t) .* out_gate;                          % equation (6)  
  137.           
  138.         % store these memory gates  
  139.         I = [I; in_gate];  
  140.         F = [F; forget_gate];  
  141.         O = [O; out_gate];  
  142.         G = [G; g_gate];  
  143.         C = [C; C_t];  
  144.         H = [H; H_t];  
  145.           
  146.         % compute predict output  
  147.         pred_out = sigmoid(H_t * out_para);  
  148.           
  149.         % compute error in output layer  
  150.         output_error = y - pred_out;  
  151.           
  152.         % compute difference in output layer using derivative  
  153.         % output_diff = output_error * sigmoid_output_to_derivative(pred_out);  
  154.         output_deltas = [output_deltas; output_error];%*sigmoid_output_to_derivative(pred_out)];  
  155. %         output_deltas = [output_deltas; output_error*(pred_out)];  
  156.         % compute total error  
  157.         % note that if the size of pred_out or target is 1 x n or m x n,  
  158.         % you should use other approach to compute error. here the dimension   
  159.         % of pred_out is 1 x 1  
  160.         overallError = overallError + abs(output_error(1));  
  161.           
  162.         % decode estimate so we can print it out  
  163.         d(binary_dim - position) = round(pred_out);  
  164.     end  
  165.       
  166.     % from the last LSTM cell, you need a initial hidden layer difference  
  167.     future_H_diff = zeros(1, hidden_dim);  
  168.       
  169.     % stare back-propagation, i.e., a backward pass  
  170.     % the goal is to compute differences and use them to update weights  
  171.     % start from the last LSTM cell  
  172.     for position = 0:binary_dim-1  
  173.         X = [a(position+1)-'0' b(position+1)-'0'];  
  174.           
  175.         % hidden layer  
  176.         H_t = H(end-position, :);         % H(t)  
  177.         % previous hidden layer  
  178.         H_t_1 = H(end-position-1, :);     % H(t-1)  
  179.         C_t = C(end-position, :);         % C(t)  
  180.         C_t_1 = C(end-position-1, :);     % C(t-1)  
  181.         O_t = O(end-position, :);  
  182.         F_t = F(end-position, :);  
  183.         G_t = G(end-position, :);  
  184.         I_t = I(end-position, :);  
  185.           
  186.         % output layer difference  
  187.         output_diff = output_deltas(end-position, :);  
  188.           
  189.         % hidden layer difference  
  190.         % note that here we consider one hidden layer is input to both  
  191.         % output layer and next LSTM cell. Thus its difference also comes  
  192.         % from two sources. In some other method, only one source is taken  
  193.         % into consideration.  
  194.         % use the equation: delta(l) = (delta(l+1) * W(l+1)) .* f'(z) to  
  195.         % compute difference in previous layers. look for more about the  
  196.         % proof at http://neuralnetworksanddeeplearning.com/chap2.html  
  197. %         H_t_diff = (future_H_diff * (H_i' + H_o' + H_f' + H_g') + output_diff * out_para') ...  
  198. %                    .* sigmoid_output_to_derivative(H_t);  
  199.   
  200.   
  201.         H_t_diff = output_diff * (out_para');% .* sigmoid_output_to_derivative(H_t);  
  202. %         H_t_diff = output_diff * (out_para') .* sigmoid_output_to_derivative(H_t);  
  203. %         future_H_diff = H_t_diff;  
  204. %         out_para_diff = output_diff * (H_t) * sigmoid_output_to_derivative(out_para);  
  205.         out_para_diff =  (H_t') * output_diff;%输出层权重  
  206.   
  207.   
  208.         % out_gate diference  
  209.         O_t_diff = H_t_diff .* tan_h(C_t) .* sigmoid_output_to_derivative(O_t);  
  210.           
  211.         % C_t difference  
  212.         C_t_diff = H_t_diff .* O_t .* tan_h_output_to_derivative(C_t);  
  213.           
  214. %         % C(t-1) difference  
  215. %         C_t_1_diff = C_t_diff .* F_t;  
  216.           
  217.         % forget_gate_diffeence  
  218.         F_t_diff = C_t_diff .* C_t_1 .* sigmoid_output_to_derivative(F_t);  
  219.           
  220.         % in_gate difference  
  221.         I_t_diff = C_t_diff .* G_t .* sigmoid_output_to_derivative(I_t);  
  222.           
  223.         % g_gate difference  
  224.         G_t_diff = C_t_diff .* I_t .* tan_h_output_to_derivative(G_t);  
  225.           
  226.         % differences of X_i and H_i  
  227.         X_i_diff =  X' * I_t_diff;% .* sigmoid_output_to_derivative(X_i);  
  228.         H_i_diff =  (H_t_1)' * I_t_diff;% .* sigmoid_output_to_derivative(H_i);  
  229.           
  230.         % differences of X_o and H_o  
  231.         X_o_diff = X' * O_t_diff;% .* sigmoid_output_to_derivative(X_o);  
  232.         H_o_diff = (H_t_1)' * O_t_diff;% .* sigmoid_output_to_derivative(H_o);  
  233.           
  234.         % differences of X_o and H_o  
  235.         X_f_diff = X' * F_t_diff;% .* sigmoid_output_to_derivative(X_f);  
  236.         H_f_diff = (H_t_1)' * F_t_diff;% .* sigmoid_output_to_derivative(H_f);  
  237.           
  238.         % differences of X_o and H_o  
  239.         X_g_diff = X' * G_t_diff;% .* tan_h_output_to_derivative(X_g);  
  240.         H_g_diff = (H_t_1)' * G_t_diff;% .* tan_h_output_to_derivative(H_g);  
  241.           
  242.         % update  
  243.         X_i_update = X_i_update + X_i_diff;  
  244.         H_i_update = H_i_update + H_i_diff;  
  245.         X_o_update = X_o_update + X_o_diff;  
  246.         H_o_update = H_o_update + H_o_diff;  
  247.         X_f_update = X_f_update + X_f_diff;  
  248.         H_f_update = H_f_update + H_f_diff;  
  249.         X_g_update = X_g_update + X_g_diff;  
  250.         H_g_update = H_g_update + H_g_diff;  
  251.         bi_update = bi_update + I_t_diff;  
  252.         bo_update = bo_update + O_t_diff;  
  253.         bf_update = bf_update + F_t_diff;  
  254.         bg_update = bg_update + G_t_diff;                          
  255.         out_para_update = out_para_update + out_para_diff;  
  256.     end  
  257.       
  258.     X_i = X_i + X_i_update * alpha;   
  259.     H_i = H_i + H_i_update * alpha;  
  260.     X_o = X_o + X_o_update * alpha;   
  261.     H_o = H_o + H_o_update * alpha;  
  262.     X_f = X_f + X_f_update * alpha;   
  263.     H_f = H_f + H_f_update * alpha;  
  264.     X_g = X_g + X_g_update * alpha;   
  265.     H_g = H_g + H_g_update * alpha;  
  266.     bi = bi + bi_update * alpha;  
  267.     bo = bo + bo_update * alpha;  
  268.     bf = bf + bf_update * alpha;  
  269.     bg = bg + bg_update * alpha;  
  270.     out_para = out_para + out_para_update * alpha;  
  271.       
  272.     X_i_update = X_i_update * 0;   
  273.     H_i_update = H_i_update * 0;  
  274.     X_o_update = X_o_update * 0;   
  275.     H_o_update = H_o_update * 0;  
  276.     X_f_update = X_f_update * 0;   
  277.     H_f_update = H_f_update * 0;  
  278.     X_g_update = X_g_update * 0;   
  279.     H_g_update = H_g_update * 0;  
  280.     bi_update = 0;  
  281.     bf_update = 0;  
  282.     bo_update = 0;  
  283.     bg_update = 0;  
  284.     out_para_update = out_para_update * 0;  
  285.       
  286.     if(mod(j,1000) == 0)  
  287.         if 1%overallError > 1  
  288.             err = sprintf('Error:%s\n', num2str(overallError)); fprintf(err);  
  289.         end  
  290.         allErr = [allErr overallError];  
  291. %         try  
  292.             d = bin2dec(num2str(d));  
  293. %         catch  
  294. %             disp(d);  
  295. %         end  
  296.         if 1%overallError>1  
  297.         pred = sprintf('Pred:%s\n',dec2bin(d,8)); fprintf(pred);  
  298.         Tru = sprintf('True:%s\n', num2str(c)); fprintf(Tru);  
  299.         end  
  300.         out = 0;  
  301.         tmp = dec2bin(d,8);  
  302.         for i = 1:8             
  303.             out = out + str2double(tmp(8-i+1)) * power(2,i-1);  
  304.         end  
  305.         if 1%overallError>1  
  306.         fprintf('%d + %d = %d\n',a_int,b_int,out);  
  307.         sep = sprintf('-------%d------\n', j); fprintf(sep);  
  308.         end  
  309.     end  
  310. end  
  311. figure;plot(allErr);  
  312. function output = sigmoid(x)  
  313.     output = 1./(1+exp(-x));  
  314. end  
  315.   
  316.   
  317. function y = sigmoid_output_to_derivative(output)  
  318.     y = output.*(1-output);  
  319. end  
  320.   
  321.   
  322. function y = tan_h_output_to_derivative(x)  
  323.     y = (1-x.^2);  
  324. end  
  325.   
  326.   
  327. function y=tan_h(x)  
  328. y=(exp(x)-exp(-x))./(exp(x)+exp(-x));  
  329. end  

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值