LSTM是递归神经网络(RNN)的一个变种,相较于RNN而言,解决了记忆消失的问题,用来处理序列问题是一个很好的选择。本文主要介绍如何使用DL4J中的LSTM来执行回归分析。如果不清楚RNN和LSTM,可以先阅读 LSTM和递归网络教程 以及 通过DL4J使用递归网络 ,特别是不熟悉RNN输入和预测方式的强烈建议先阅读这两个教程。如果不太会建立DL4J的工程,建议在其样例工程中进行本实验。
言归正传,文本通过使用 LSTM对上证指数历史数据进行回归学习,并给出一个初始序列预测之后20天的大盘收盘价格来演示如何使用LSTM处理简单的序列回归问题。首先是准备数据,可以下载例子中我使用的数据集。那么接下来的问题就分成如下几步:
1. 读入训练数据,并处理成一个DataIterator;
2. 构建一个LSTM的递归神经网络;
3. 迭代训练,并输出预测结果;
4. 调参和优化。
一.处理训练数据
我们的数据是上证指数每个交易日的基本数据,格式为:
股票代码 日期开盘价 收盘价最高价 最低价成交量 成交额涨跌幅
这个文件中的数据是倒序的,也就是说新的数据在最前面,因此在读取数据时需要做一次倒转。我将读取文件的方法放在Dataiterator中。DL4J给出了序列数据处理的DataIterator,但是在本例中我们是自己实现一个DataIterator。代码如下:
package edu.zju.cst.krselee.example.stock;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
/**
* Created by kexi.lkx on 2016/8/23.
*/
public class StockDataIterator implements DataSetIterator {
private static final int VECTOR_SIZE = 6;
//每批次的训练数据组数
private int batchNum;
//每组训练数据长度(DailyData的个数)
private int exampleLength;
//数据集
private List<DailyData> dataList;
//存放剩余数据组的index信息
private List<Integer> dataRecord;
private double[] maxNum;
/**
* 构造方法
* */
public StockDataIterator(){
dataRecord = new ArrayList<>();
}
/**
* 加载数据并初始化
* */
public boolean loadData(String fileName, int batchNum, int exampleLength){
this.batchNum = batchNum;
this.exampleLength = exampleLength;
maxNum = new double[6];
//加载文件中的股票数据
try {
readDataFromFile(fileName);
}catch (Exception e){
e.printStackTrace();
return false;
}
//重置训练批次列表
resetDataRecord();
return