LSTM(长短期记忆人工神经网络)实现

简介


长短期记忆人工神经网络(Long-Short Term Memory, LSTM)是一种循环神经网络(Recurrent Neural Network, RNN),论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。

由于其结构和RNN很相似,就是将单一的激活函数换成更为复杂的结构。前面《RNN(循环神经网络)训练手写数字》的数据处理和很多代码都有共通之处,本文就从简😊。

2641798-e963f2faa2d6e9b0.png

公式


LSTM的结构有很多种形式,但是都大同小异,主要都包含输入门、输出门、遗忘门。
本文实现的是其中一种较为流行的变体结构——GRU(Gated Recurrent Unit)。公式与结构图如下:

2641798-9f2ad6b1efe8befa.png

实现


相比较于简单的RNN网络,LSTM训练的参数更多,单个块的结构也更复杂。实现中,我将输入的误差也反传了,这样可以很方便的实现多层LSTM网络,或者与RNN/CNN网络结合使用。
主体代码如下:

//
//  MLLstm.m
//  LSTM
//
//  Created by Jiao Liu on 11/12/16.
//  Copyright © 2016 ChangHong. All rights reserved.
//

#import "MLLstm.h"

@implementation MLLstm

#pragma mark - Inner Method

// Draws one sample from the normal distribution N(mean, stddev^2), truncated
// by rejection to the interval [mean - 2*stddev, mean + 2*stddev].
// Standard-normal deviates come from the Marsaglia polar method.
// Fixes vs. the original:
//  - The truncation test compared fabsl(outP) against 2*stddev, which is only
//    correct when mean == 0; it now truncates relative to the mean.
//  - fabs() (double) replaces fabsl() (long double) to match the operand type.
//  - The `static` locals are gone: statics inside a class method are shared
//    across every caller (not thread-safe), and the cached "spare" deviate of
//    a rejected pair was being reused on the next draw.
// NOTE: rand() is never seeded here; seed with srand() elsewhere if a
// different sequence per run is wanted.
+ (double)truncated_normal:(double)mean dev:(double)stddev
{
    double sample = 0.0;
    do {
        double u, v, s;
        do {
            u = (rand() / ((double)RAND_MAX)) * 2.0 - 1.0;
            v = (rand() / ((double)RAND_MAX)) * 2.0 - 1.0;
            s = u * u + v * v;
        } while (s >= 1.0 || s == 0.0);
        double factor = sqrt(-2.0 * log(s) / s);
        sample = mean + stddev * u * factor;
    } while (fabs(sample - mean) > 2 * stddev);
    return sample;
}

// Returns a freshly malloc'd vector of `size` doubles, each set to `num`.
// The caller owns the buffer and is responsible for free()ing it.
+ (double *)fillVector:(double)num size:(int)size
{
    double *vec = malloc(sizeof(double) * size);
    for (int idx = 0; idx < size; idx++) {
        vec[idx] = num;
    }
    return vec;
}

// Returns a freshly malloc'd vector of `size` weights, each drawn from a
// truncated normal distribution with mean 0 and stddev 0.1.
// The caller owns the buffer and is responsible for free()ing it.
+ (double *)weight_init:(int)size
{
    double *weights = malloc(sizeof(double) * size);
    double *cursor = weights;
    for (int remaining = size; remaining > 0; remaining--) {
        *cursor++ = [MLLstm truncated_normal:0 dev:0.1];
    }
    return weights;
}

// Returns a freshly malloc'd bias vector with every entry set to 0.1.
// The caller owns the buffer and is responsible for free()ing it.
// Fixed: the constant was the float literal 0.1f, which widens to the double
// 0.10000000149011612; the double literal 0.1 gives the intended value.
+ (double *)bias_init:(int)size
{
    return [MLLstm fillVector:0.1 size:size];
}

// Applies the hyperbolic tangent element-wise, in place, and returns the same
// buffer for call chaining.
// Replaced the hand-rolled (e^x - e^-x)/(e^x + e^-x) — four exp() calls per
// element plus a ±20 clamp to dodge overflow — with the C library tanh(),
// which is accurate over the whole range and saturates to ±1 on its own.
+ (double *)tanh:(double *)input size:(int)size
{
    for (int i = 0; i < size; i++) {
        input[i] = tanh(input[i]);
    }
    return input;
}

// Applies the logistic sigmoid element-wise, in place, and returns the same
// buffer for call chaining.
// Replaced exp(x)/(exp(x)+1) — two exp() calls per element plus a ±20 clamp —
// with the equivalent 1/(1 + e^-x): one exp() call, and no clamp needed
// because exp(-x) underflows to 0 for large x (giving 1) and overflows to
// +inf for very negative x (1/inf == 0), both the correct limits.
+ (double *)sigmoid:(double *)input size:(int)size
{
    for (int i = 0; i < size; i++) {
        input[i] = 1.0 / (1.0 + exp(-input[i]));
    }
    return input;
}

#pragma mark - Init

// Initializes the network with its dimensions, then allocates all state
// buffers and parameters via setupNet.
// @param num  number of hidden units (nodes) per time step
// @param size number of unrolled time steps (sequence length)
// @param dim  dimensionality of each per-step input vector
- (id)initWithNodeNum:(int)num layerSize:(int)size dataDim:(int)dim
{
    self = [super init];
    if (self) {
        _nodeNum = num;
        _layerSize = size;
        _dataDim = dim;
        [self setupNet];
    }
    return self;
}

// Bare -init: _nodeNum/_layerSize/_dataDim are still zero here, so setupNet
// allocates zero-length buffers.
// NOTE(review): presumably callers are expected to use
// -initWithNodeNum:layerSize:dataDim: instead — confirm.
- (id)init
{
    self = [super init];
    if (self) {
        [self setupNet];
    }
    return self;
}

// Allocates all per-sequence state buffers and initializes the GRU
// parameters from the current _nodeNum/_layerSize/_dataDim.
// State layout (one _nodeNum-sized slice per time step, zero-initialized):
//   _hState  – hidden state h(t)       _rState – reset gate r(t)
//   _zState  – update gate z(t)        _hbState – candidate state h~(t)
// _output/_backLoss hold _dataDim values per time step.
// Parameters: the *W matrices map input (_dataDim) -> hidden (_nodeNum),
// the *U matrices map hidden -> hidden, and _outW maps hidden -> _dataDim.
// NOTE(review): nothing in this file frees these buffers — calling setupNet
// twice, or deallocating the object, leaks them; confirm ownership elsewhere.
- (void)setupNet
{
    _hState = calloc(_layerSize * _nodeNum, sizeof(double));
    _rState = calloc(_layerSize * _nodeNum, sizeof(double));
    _zState = calloc(_layerSize * _nodeNum, sizeof(double));
    _hbState = calloc(_layerSize * _nodeNum, sizeof(double));
    _output = calloc(_layerSize * _dataDim, sizeof(double));
    _backLoss = calloc(_layerSize * _dataDim, sizeof(double));
    
    _rW = [MLLstm weight_init:_nodeNum * _dataDim];
    _rU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _rBias = [MLLstm bias_init:_nodeNum];
    _zW = [MLLstm weight_init:_nodeNum * _dataDim];
    _zU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _zBias = [MLLstm bias_init:_nodeNum];
    _hW = [MLLstm weight_init:_nodeNum * _dataDim];
    _hU = [MLLstm weight_init:_nodeNum * _nodeNum];
    _hBias = [MLLstm bias_init:_nodeNum];
    _outW = [MLLstm weight_init:_dataDim * _nodeNum];
    _outBias = [MLLstm bias_init:_dataDim];
}

// GRU forward pass over a whole sequence. `input` holds _layerSize time
// steps of _dataDim values, laid out contiguously; the pointer is kept (not
// copied) and must stay valid until backPropagation: runs. Fills the
// per-step gate/state buffers and returns _output (_layerSize * _dataDim),
// which is owned by this object — callers must not free it.
- (double *)forwardPropagation:(double *)input
{
    _input = input;
    // Zero all state/output buffers so the vDSP_vaddD "accumulate into state"
    // calls below effectively store fresh values for this sequence.
    double zero = 0;
    vDSP_vfillD(&zero, _output, 1, _layerSize * _dataDim);
    vDSP_vfillD(&zero, _hState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _rState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _zState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _hbState, 1, _layerSize * _nodeNum);
    vDSP_vfillD(&zero, _backLoss, 1, _layerSize * _dataDim);
    
    double *temp1 = calloc(_nodeNum, sizeof(double));
    double *temp2 = calloc(_nodeNum, sizeof(double));
    double *temp3 = calloc(_nodeNum, sizeof(double));
    double *one = [MLLstm fillVector:1 size:_nodeNum];
    for (int i = 0; i < _layerSize; i++) {
        // Reset gate: r(t) = sigmoid(Wr*x(t) + Ur*h(t-1) + rBias)
        // (at t == 0 there is no previous hidden state, so the Ur term drops)
        if (i == 0) {
            vDSP_mmulD(_rW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _rBias, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_mmulD(_rW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_rU, 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _rBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm sigmoid:temp1 size:_nodeNum];
        vDSP_vaddD((_rState + i * _nodeNum), 1, temp1, 1, (_rState + i * _nodeNum), 1, _nodeNum);
        
        // Update gate: z(t) = sigmoid(Wz*x(t) + Uz*h(t-1) + zBias)
        if (i == 0) {
            vDSP_mmulD(_zW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _zBias, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_mmulD(_zW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_mmulD(_zU, 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _zBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm sigmoid:temp1 size:_nodeNum];
        vDSP_vaddD((_zState + i * _nodeNum), 1, temp1, 1, (_zState + i * _nodeNum), 1, _nodeNum);
        
        // Candidate state: h~(t) = tanh(W*x(t) + U*(r(t) ⊙ h(t-1)) + hBias)
        if (i == 0) {
            vDSP_mmulD(_hW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vaddD(temp1, 1, _hBias, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_mmulD(_hW, 1, (_input + i * _dataDim), 1, temp1, 1, _nodeNum, 1, _dataDim);
            vDSP_vmulD((_rState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum);
            vDSP_mmulD(_hU, 1, temp2, 1, temp3, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp3, 1, temp1, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, _hBias, 1, temp1, 1, _nodeNum);
        }
        [MLLstm tanh:temp1 size:_nodeNum];
        vDSP_vaddD((_hbState + i * _nodeNum), 1, temp1, 1, (_hbState + i * _nodeNum), 1, _nodeNum);
        
        // New hidden state: h(t) = z(t) ⊙ h(t-1) + (1 − z(t)) ⊙ h~(t)
        // (note vDSP_vsubD(a,…,b,…) computes b - a, so temp1 = 1 - z(t))
        if (i == 0) {
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_hbState + i * _nodeNum), 1, temp1, 1, temp1, 1, _nodeNum);
        }
        else
        {
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_hbState + i * _nodeNum), 1, temp1, 1, temp1, 1, _nodeNum);
            vDSP_vmulD((_zState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp2, 1, _nodeNum);
            vDSP_vaddD(temp1, 1, temp2, 1, temp1, 1, _nodeNum);
        }
        vDSP_vaddD((_hState + i * _nodeNum), 1, temp1, 1, (_hState + i * _nodeNum), 1, _nodeNum);
        
        // Linear readout: output(t) = outW * h(t) + outBias
        vDSP_mmulD(_outW, 1, (_hState + i * _nodeNum), 1, (_output + i * _dataDim), 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD(_outBias, 1, (_output + i * _dataDim), 1, (_output + i * _dataDim), 1, _dataDim);
    }
    free(one);
    free(temp1);
    free(temp2);
    free(temp3);
    
    return _output;
}

// Backpropagation through time for the GRU, walking the sequence from the
// last step to the first. `loss` holds the per-step gradient of the loss
// w.r.t. _output (_layerSize * _dataDim). Parameters are updated in place by
// adding the computed gradients directly.
// NOTE(review): no learning rate appears anywhere in this method — presumably
// `loss` is pre-scaled (learning rate and sign folded in) by the caller;
// confirm against the training loop.
// Returns _backLoss, the gradient w.r.t. the inputs, so an upstream layer
// (another LSTM/RNN/CNN) can continue the chain. Owned by this object —
// callers must not free it.
- (double *)backPropagation:(double *)loss
{
    // Scratch buffers, freed at the end of the method.
    double *flowLoss = calloc(_nodeNum, sizeof(double));          // gradient flowing back into h(t-1)
    double *outTW = calloc(_nodeNum * _dataDim, sizeof(double));  // transpose of _outW
    double *outLoss = calloc(_nodeNum, sizeof(double));           // gradient w.r.t. h(t)
    double *outWLoss = calloc(_dataDim * _nodeNum, sizeof(double));
    double *temp1 = calloc(_nodeNum, sizeof(double));
    double *one = [MLLstm fillVector:1 size:_nodeNum];
    double *zLoss = calloc(_nodeNum, sizeof(double));             // gradient at the z-gate pre-activation
    double *hbLoss = calloc(_nodeNum, sizeof(double));            // gradient at the candidate pre-activation
    double *inWLoss = calloc(_nodeNum * _dataDim, sizeof(double));
    double *rLoss = calloc(_nodeNum, sizeof(double));             // gradient at the r-gate pre-activation
    double *tU = calloc(_nodeNum * _nodeNum, sizeof(double));     // transposed U matrices
    double *uLoss = calloc(_nodeNum * _nodeNum, sizeof(double));
    double *tW = calloc(_dataDim * _nodeNum, sizeof(double));     // transposed W matrices
    double *temp2 = calloc(_dataDim, sizeof(double));
    for (int i = _layerSize - 1; i >= 0; i--) {
        // Output layer: accumulate bias/weight gradients and map the loss
        // back through outW^T to get the gradient w.r.t. h(t).
        vDSP_vaddD(_outBias, 1, (loss + i * _dataDim), 1, _outBias, 1, _dataDim);
        vDSP_mtransD(_outW, 1, outTW, 1, _nodeNum, _dataDim);
        vDSP_mmulD(outTW, 1, (loss + i * _dataDim), 1, outLoss, 1, _nodeNum, 1, _dataDim);
        vDSP_mmulD((loss + i * _dataDim), 1, (_hState + i * _nodeNum), 1, outWLoss, 1, _dataDim, _nodeNum, 1);
        vDSP_vaddD(_outW, 1, outWLoss, 1, _outW, 1, _dataDim * _nodeNum);
        
        // h(t) total gradient: output-path gradient plus, for all but the
        // last step, the gradient flowing back from step t+1.
        if (i != _layerSize - 1) {
            vDSP_vaddD(outLoss, 1, flowLoss, 1, outLoss, 1, _nodeNum);
        }
        if (i > 0) {
            // From h(t) = z⊙h(t-1) + (1-z)⊙h~:
            //   dL/dz  = dL/dh ⊙ (h(t-1) - h~)
            //   (vDSP_vsubD(a,…,b,…) computes b - a, hence the operand order)
            vDSP_vsubD((_hState + (i-1) * _nodeNum), 1, (_hbState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vmulD(outLoss, 1, temp1, 1, zLoss, 1, _nodeNum);
            
            //   dL/dh(t-1) gets the direct z-path term dL/dh ⊙ z... see review
            // NOTE(review): temp1 here is (1 - z), so flowLoss starts as
            // dL/dh ⊙ (1 - z); the forward pass pairs z with h(t-1) — confirm
            // this term should not be dL/dh ⊙ z instead.
            vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD(outLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
        }
        else
        {
            // At t == 0, h(t) = (1-z)⊙h~ only; the h(t-1) terms vanish.
            vDSP_vmulD(outLoss, 1, (_hbState + i * _nodeNum), 1, zLoss, 1, _nodeNum);
        }
        // Sigmoid derivative: σ'(x) = σ(x)·(1-σ(x)), using the stored z(t).
        vDSP_vsubD((_zState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
        vDSP_vmulD(temp1, 1, (_zState + i * _nodeNum), 1, temp1, 1, _nodeNum);
        vDSP_vmulD(temp1, 1, zLoss, 1, zLoss, 1, _nodeNum);
        
        // Candidate path: dL/dh~ = dL/dh ⊙ z ... (see review note above on
        // the z vs. (1-z) pairing), then apply tanh'(x) = 1 - tanh(x)^2.
        vDSP_vmulD(outLoss, 1, (_zState + i * _nodeNum), 1, hbLoss, 1, _nodeNum);
        vDSP_vsqD((_hbState + i * _nodeNum), 1, temp1, 1, _nodeNum);
        vDSP_vsubD(temp1, 1, one, 1, temp1, 1, _nodeNum);
        vDSP_vmulD(hbLoss, 1, temp1, 1, hbLoss, 1, _nodeNum);
        
        // Candidate (h~) parameters: bias, input-weight gradient, and the
        // contribution to the input gradient _backLoss via hW^T.
        vDSP_vaddD(_hBias, 1, hbLoss, 1, _hBias, 1, _nodeNum);
        vDSP_mtransD(_hW, 1, tW, 1, _dataDim, _nodeNum);
        vDSP_mmulD(tW, 1, hbLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
        vDSP_mmulD(hbLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_hW, 1, inWLoss, 1, _hW, 1, _nodeNum * _dataDim);

        if (i > 0) {
            // Through U*(r ⊙ h(t-1)): gradient w.r.t. r, w.r.t. h(t-1)
            // (added to flowLoss), and the hU weight gradient.
            vDSP_mtransD(_hU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, hbLoss, 1, rLoss, 1, _nodeNum, 1, _nodeNum);
            vDSP_vmulD(rLoss, 1, (_hState + (i-1) * _nodeNum), 1, rLoss, 1, _nodeNum);
            // Sigmoid derivative for the r gate.
            vDSP_vsubD((_rState + i * _nodeNum), 1, one, 1, temp1, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, (_rState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, rLoss, 1, rLoss, 1, _nodeNum);
            
            vDSP_mmulD(tU, 1, hbLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vmulD(temp1, 1, (_rState + i * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            
            vDSP_vmulD((_rState + i * _nodeNum), 1, (_hState + (i-1) * _nodeNum), 1, temp1, 1, _nodeNum);
            vDSP_mmulD(hbLoss, 1, temp1, 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_hU, 1, uLoss, 1, _hU, 1, _nodeNum * _nodeNum);
        }
        
        // Update-gate (z) parameters: bias, input weights, input gradient.
        vDSP_vaddD(_zBias, 1, zLoss, 1, _zBias, 1, _nodeNum);
        vDSP_mtransD(_zW, 1, tW, 1, _dataDim, _nodeNum);
        vDSP_mmulD(tW, 1, zLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
        vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
        vDSP_mmulD(zLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
        vDSP_vaddD(_zW, 1, inWLoss, 1, _zW, 1, _nodeNum * _dataDim);
        
        if (i > 0) {
            // z-gate recurrent path: contribution to dL/dh(t-1) and zU update.
            vDSP_mtransD(_zU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, zLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            
            vDSP_mmulD(zLoss, 1, (_hState + (i-1) * _nodeNum), 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_zU, 1, uLoss, 1, _zU, 1, _nodeNum * _nodeNum);
        }
        
        // Reset-gate (r) parameters — only defined for t > 0, matching the
        // forward pass where r has no effect at the first step.
        if (i > 0) {
            vDSP_vaddD(_rBias, 1, rLoss, 1, _rBias, 1, _nodeNum);
            vDSP_mtransD(_rW, 1, tW, 1, _dataDim, _nodeNum);
            vDSP_mmulD(tW, 1,rLoss, 1, temp2, 1, _dataDim, 1, _nodeNum);
            vDSP_vaddD((_backLoss + i * _dataDim), 1, temp2, 1, (_backLoss + i * _dataDim), 1, _dataDim);
            vDSP_mmulD(rLoss, 1, (_input + i * _dataDim), 1, inWLoss, 1, _nodeNum, _dataDim, 1);
            vDSP_vaddD(_rW, 1, inWLoss, 1, _rW, 1, _nodeNum * _dataDim);
            
            vDSP_mtransD(_rU, 1, tU, 1, _nodeNum, _nodeNum);
            vDSP_mmulD(tU, 1, rLoss, 1, temp1, 1, _nodeNum, 1, _nodeNum);
            vDSP_vaddD(flowLoss, 1, temp1, 1, flowLoss, 1, _nodeNum);
            
            vDSP_mmulD(rLoss, 1, (_hState + (i-1) * _nodeNum), 1, uLoss, 1, _nodeNum, _nodeNum, 1);
            vDSP_vaddD(_rU, 1, uLoss, 1, _rU, 1, _nodeNum * _nodeNum);
        }
    }
    
    free(flowLoss);
    free(outTW);
    free(outLoss);
    free(outWLoss);
    free(temp1);
    free(one);
    free(zLoss);
    free(hbLoss);
    free(inWLoss);
    free(rLoss);
    free(tU);
    free(uLoss);
    free(tW);
    free(temp2);
    return _backLoss;
}

@end

结语


这里同样用MNIST数据训练了单层LSTM的效果,参数选用单个神经元节点500,迭代1300,一次5张图片,得到90%左右正确率。

多次尝试发现神经元节点个数越大,单次迭代训练时间越长,准确率越高。所以将节点个数设到500,为了加快速度将一次迭代图片数由RNN网络的100降到5张,但是整个过程还是花了3个多小时😭。其效果不及CNN、RNN在相似环境下的表现。

有兴趣的朋友可以点这里看完整代码

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值