DeepLearning4J入门——使用LSTM进行大盘回归

       LSTM是递归神经网络(RNN)的一个变种,相较于RNN而言,解决了记忆消失的问题,用来处理序列问题是一个很好的选择。本文主要介绍如何使用DL4J中的LSTM来执行回归分析。如果不清楚RNN和LSTM,可以先阅读 LSTM和递归网络教程 以及 通过DL4J使用递归网络 ,特别是不熟悉RNN输入和预测方式的强烈建议先阅读这两个教程。如果不太会建立DL4J的工程,建议在其样例工程中进行本实验。

      言归正传,文本通过使用 LSTM对上证指数历史数据进行回归学习,并给出一个初始序列预测之后20天的大盘收盘价格来演示如何使用LSTM处理简单的序列回归问题。首先是准备数据,可以下载例子中我使用的数据集。那么接下来的问题就分成如下几步:

1. 读入训练数据,并处理成一个DataIterator;

2. 构建一个LSTM的递归神经网络;

3. 迭代训练,并输出预测结果;

4. 调参和优化。


一.处理训练数据

我们的数据是上证指数每个交易日的基本数据,格式为:

股票代码 日期开盘价 收盘价最高价  最低价成交量  成交额涨跌幅

        这个文件中的数据是倒序的,也就是说新的数据在最前面,因此在读取数据时需要做一次倒转。我将读取文件的方法放在Dataiterator中。DL4J给出了序列数据处理的DataIterator,但是在本例中我们是自己实现一个DataIterator。代码如下:

package edu.zju.cst.krselee.example.stock;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Created by kexi.lkx on 2016/8/23.
 */
public class StockDataIterator  implements DataSetIterator {

    private static final int VECTOR_SIZE = 6;
    //每批次的训练数据组数
    private int batchNum;

    //每组训练数据长度(DailyData的个数)
    private int exampleLength;

    //数据集
    private List<DailyData> dataList;

    //存放剩余数据组的index信息
    private List<Integer> dataRecord;

    private double[] maxNum;
    /**
     * 构造方法
     * */
    public StockDataIterator(){
        dataRecord = new ArrayList<>();
    }

    /**
     * 加载数据并初始化
     * */
    public boolean loadData(String fileName, int batchNum, int exampleLength){
        this.batchNum = batchNum;
        this.exampleLength = exampleLength;
        maxNum = new double[6];
        //加载文件中的股票数据
        try {
            readDataFromFile(fileName);
        }catch (Exception e){
            e.printStackTrace();
            return false;
        }
        //重置训练批次列表
        resetDataRecord();
        return
  • 5
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 22
    评论
下面是一个简单的使用MATLAB Deep Learning Toolbox中的LSTM网络进行脑电数据二分类的示例: ```matlab % 加载数据 load eegdata % 将数据划分为训练集和测试集 numObservations = size(X,2); idx = randperm(numObservations); numTrain = floor(0.7*numObservations); idxTrain = idx(1:numTrain); idxTest = idx(numTrain+1:end); XTrain = X(:,idxTrain,:); YTrain = categorical(Y(idxTrain)); XTest = X(:,idxTest,:); YTest = categorical(Y(idxTest)); % 创建LSTM网络 numFeatures = size(XTrain,1); numClasses = numel(categories(YTrain)); numHiddenUnits = 100; layers = [ ... sequenceInputLayer(numFeatures) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]; % 配置LSTM网络的训练选项 options = trainingOptions('adam', ... 'MaxEpochs', 30, ... 'MiniBatchSize', 64, ... 'InitialLearnRate', 0.01, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor', 0.1, ... 'LearnRateDropPeriod', 10, ... 'Shuffle','every-epoch', ... 'ValidationData',{XTest,YTest}, ... 'ValidationFrequency',30, ... 'Verbose',false, ... 'Plots','training-progress'); % 训练LSTM网络 net = trainNetwork(XTrain,YTrain,layers,options); % 使用测试集评估训练好的LSTM网络 YPred = classify(net,XTest); % 计算分类准确率 accuracy = sum(YPred == YTest)/numel(YTest); fprintf('分类准确率为 %.2f%%\n',accuracy*100); ``` 在此示例中,我们加载了脑电数据,并将其划分为训练集和测试集。然后,我们创建了一个简单的LSTM网络,使用Adam优化器训练了30个epoch,并使用测试集评估了网络。最后,我们计算了分类准确率。请注意,此示例仅用于演示如何使用LSTM网络进行脑电数据分类,实际上,您可能需要更复杂的网络结构和更多的训练数据来获得更好的分类性能。
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值