This is a sequence classification problem solved with an LSTM recurrent network: based on the trend of each series, the data falls into 6 classes: normal, cyclic, upward shift, downward shift, increasing trend, and decreasing trend.
Data: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data
Image: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
The processing steps are as follows:
1. Download and prepare the data
(a) 600 examples in total: 450 for training, 150 for testing
(b) Convert the data into a suitable format and read it with CSVSequenceRecordReader
Format: one file per time series, plus a separate label file per series. For example, train/features/0.csv is a feature file and train/labels/0.csv is its corresponding label file. Since the data is a univariate time series, each features CSV has a single column with one time step per row (no row holds more than one value), and each label file contains a single label value. A sketch of what these files look like follows this list.
2. Load the training data with CSVSequenceRecordReader and convert it into DataSets with SequenceRecordReaderDataSetIterator
3. Normalize the data: collect statistics from the training data only, then normalize the training and test data the same way
4. Configure the network: a small LSTM layer plus an RNN output layer
5. Train for 40 epochs, printing the test-set accuracy and F1 score at each epoch
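To make the file format in step 1(b) concrete, here is a minimal sanity-check sketch. It is not part of the original example (the class name PeekUCIData is mine); run it after downloadUCIData() in the code below has created the files:

import java.nio.file.Files;
import java.nio.file.Paths;

public class PeekUCIData {
    public static void main(String[] args) throws Exception {
        // Features file: one column, one value per line, one line per time step (each series has 60 steps)
        System.out.println(new String(Files.readAllBytes(
                Paths.get("src/main/resources/uci/train/features/0.csv"))));
        // Label file: a single integer class index in [0, 5]
        System.out.println(new String(Files.readAllBytes(
                Paths.get("src/main/resources/uci/train/labels/0.csv"))));
    }
}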
The code is as follows:
//Imports for reference; the exact packages vary a little across DL4J versions
//(these match the DataVec-era examples; in particular, the location of Pair differs between versions):
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class UCISequenceClassificationExample {
    private static final Logger log = LoggerFactory.getLogger(UCISequenceClassificationExample.class);  //declare a logger so we can print with log.info

    //'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
    private static File baseDir = new File("src/main/resources/uci/");  //a sensible way to lay out the many files this example produces
    private static File baseTrainDir = new File(baseDir, "train");      //File(parent, child): first argument is the parent directory, second the child
    private static File featuresDirTrain = new File(baseTrainDir, "features");  //training features directory
    private static File labelsDirTrain = new File(baseTrainDir, "labels");      //training labels directory
    private static File baseTestDir = new File(baseDir, "test");                //test directory
    private static File featuresDirTest = new File(baseTestDir, "features");    //test features directory
    private static File labelsDirTest = new File(baseTestDir, "labels");        //test labels directory

    public static void main(String[] args) throws Exception {
        downloadUCIData();  //download the UCI data; feel free to jump to that method first

        // ----- Load the training data -----
        //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();  //a CSV sequence reader
        //NumberedFileInputSplit is an InputSplit implementation: the first argument is a path pattern,
        //the next two are the inclusive range that %d runs over
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();    //same again for the training labels
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

        int miniBatchSize = 10;    //10 examples per minibatch
        int numLabelClasses = 6;   //6 label classes in total
        //Combine the feature and label readers into training DataSets. 'false' means classification rather than regression.
        //The last argument is an AlignmentMode enum describing how labels line up with input time steps
        //(one-to-many, many-to-one, e.g. 10 input time steps but a single output). The options are:
        //  EQUAL_LENGTH: default; the label sequence is as long as the input time series
        //  ALIGN_START:  the label sits at the first time step of the input series
        //  ALIGN_END:    the label sits at the last time step of the input series
        //Series of unequal length are zero-padded on the right to equal length. With the latter two modes,
        //masking is used: the returned DataSets carry mask arrays describing which input/label entries are real.
        //This example uses ALIGN_END: many input time steps map to a single output at the final step
        //(worth pausing to think through; a mask-inspection sketch follows the full example).
        //Also note that RNNs use this sequence iterator rather than the RecordReaderDataSetIterator used for CNNs;
        //both implement the DataSetIterator interface.
        DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
                false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Normalize the training data
        DataNormalization normalizer = new NormalizerStandardize();  //declare the normalizer
        normalizer.fit(trainData);   //Collect training data statistics (this consumes the iterator)
        trainData.reset();           //rewind the iterator
        //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by the 'trainData' iterator will be normalized
        trainData.setPreProcessor(normalizer);

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
                false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        testData.setPreProcessor(normalizer);  //Note that we are using the exact same normalizer, fit on the training data

        // ----- Configure the network -----
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()  //mostly the same parameters as in earlier posts
                .seed(123)    //Random number generator seed for improved repeatability. Optional.
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .learningRate(0.005)
                //Gradient normalization: not always required, but it helps with this data set by guarding against
                //exploding gradients. The GradientNormalization enum offers:
                //  None: default, no gradient normalization
                //  RenormalizeL2PerLayer: divide by the L2 norm of that layer's gradients
                //  RenormalizeL2PerParamType: weight gradients and bias gradients are each divided by their own L2 norm
                //  ClipElementWiseAbsoluteValue: clip each gradient element, i.e. g <- sign(g) * min(maxAllowedValue, |g|),
                //    where maxAllowedValue is the threshold (a tiny demo of this rule follows the full example)
                //  ClipL2PerLayer: like RenormalizeL2PerLayer, but gradients whose L2 norm is within the threshold are
                //    left unchanged; only those outside it are divided by their L2 norm
                //  ClipL2PerParamType: like ClipL2PerLayer, except weight and bias gradients are handled separately
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(0.5)  //the threshold used by ClipElementWiseAbsoluteValue
                .list()
                .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10).build())  //LSTM input layer: 1 input, 10 outputs
                //RNN output layer: 10 inputs, 6 outputs. Compare with earlier posts: in the CSV-regression post,
                //input and output sizes were both 1; in the recurrent-network post, both equaled the alphabet size;
                //here the input is again a sequence with input size 1, but the output size is the number of classes.
                //Note the differences.
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation("softmax").nIn(10).nOut(numLabelClasses).build())
                .pretrain(false).backprop(true).build();

        //Load the configuration, initialize the network and set a listener
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(20));  //Print the score (loss function value) every 20 iterations

        // ----- Train the network, evaluating the test set performance at each epoch -----
        int nEpochs = 40;
        String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
        for (int i = 0; i < nEpochs; i++) {  //loop over epochs
            net.fit(trainData);  //fit on the training data

            //Evaluate on the test set:
            Evaluation evaluation = net.evaluate(testData);
            //F1 is an evaluation metric that combines the classifier's precision and recall
            log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));

            testData.reset();   //rewind the test data
            trainData.reset();  //rewind the training data
        }
        log.info("----- Example Complete -----");
    }

    //This method downloads the data, and converts the "one time series per line" format into a suitable
    //CSV sequence format that DataVec (CSVSequenceRecordReader) and DL4J can read: one time step per row
    private static void downloadUCIData() throws Exception {
        if (baseDir.exists()) return;  //Data already exists, don't download it again

        String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
        String data = IOUtils.toString(new URL(url));  //downloading amounts to reading the URL's contents into a string; Java makes this easy
        String[] lines = data.split("\n");  //split into lines

        //Create directories; note the order: parents must be created before their children
        baseDir.mkdir();
        baseTrainDir.mkdir();
        featuresDirTrain.mkdir();
        labelsDirTrain.mkdir();
        baseTestDir.mkdir();
        featuresDirTest.mkdir();
        labelsDirTest.mkdir();

        int lineCount = 0;
        List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();  //a list whose elements are (series content, label) pairs
        for (String line : lines) {  //iterate over the raw data
            String transposed = line.replaceAll(" +", "\n");  //replace runs of spaces with newlines, so each series becomes a single CSV column
            //Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
            contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));  //the label increments every 100 examples
        }

        //Randomize and do a train/test split:
        Collections.shuffle(contentAndLabels, new Random(12345));

        int nTrain = 450;  //75% train, 25% test
        int trainCount = 0;
        int testCount = 0;
        for (Pair<String, Integer> p : contentAndLabels) {  //iterate over every example
            //Write output in a format we can read, in the appropriate locations
            File outPathFeatures;  //declared once, assigned each iteration: fewer objects, less memory
            File outPathLabels;
            if (trainCount < nTrain) {  //fill the training set first
                outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
                outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
                trainCount++;
            } else {
                outPathFeatures = new File(featuresDirTest, testCount + ".csv");
                outPathLabels = new File(labelsDirTest, testCount + ".csv");
                testCount++;
            }
            FileUtils.writeStringToFile(outPathFeatures, p.getFirst());           //write the series values
            FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString()); //write the class index
        }
    }
}
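To see the ALIGN_END alignment at work, here is a minimal sketch, again not part of the original example (the class name MaskInspection and the helper method are mine). Pass it the trainData iterator built above; with ALIGN_END, the labels mask should be 1.0 only at the final time step of each series:

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.Arrays;

public class MaskInspection {
    //'iter' is assumed to be the trainData iterator from the example above
    static void printOneBatch(DataSetIterator iter) {
        DataSet ds = iter.next();
        //Features shape: [miniBatchSize, numInputs = 1, timeSeriesLength]
        System.out.println(Arrays.toString(ds.getFeatures().shape()));
        //Labels mask: 1.0 at the last time step of each series, 0.0 elsewhere,
        //marking the single step at which the class label applies
        System.out.println(ds.getLabelsMaskArray());
        iter.reset();  //rewind so training still sees the full data
    }
}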
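Similarly, the element-wise clipping rule behind ClipElementWiseAbsoluteValue, g <- sign(g) * min(threshold, |g|), can be illustrated with a few lines of plain Java (the clip method here is purely illustrative, not a DL4J API):

public class ClipDemo {
    //g <- sign(g) * min(threshold, |g|)
    static double clip(double g, double threshold) {
        return Math.signum(g) * Math.min(threshold, Math.abs(g));
    }

    public static void main(String[] args) {
        System.out.println(clip(0.3, 0.5));   //0.3  : within the threshold, left unchanged
        System.out.println(clip(-1.7, 0.5));  //-0.5 : magnitude clipped to the threshold
    }
}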