数据加载
本例中的数据加载是用的异步方式,所以看主线程代码,根本找不到加载数据的地方。开启异步加载数据线程是在model.fit(trainIter)这个方法中。首先判断数据集迭代器是否支持异步加载if (iterator.asyncSupported()) 。由于本例中的是RecordReaderDataSetIterator,其asyncSupported()方法是返回true的,支持异步加载。创建AsyncDataSetIterator对象,在AsyncDataSetIterator的构造方法中,创建了AsyncPrefetchThread线程,并启动。
AsyncPrefetchThread线程启动后,操作系统会通过线程调度运行run()方法。其读取linear_data_train.csv文件的线程堆栈如下:
Daemon Thread [ADSI prefetch thread] (Suspended)
LineIterator.hasNext() line: 95
CSVRecordReader(LineRecordReader).hasNext() line: 105
CSVRecordReader.hasNext() line: 140
RecordReaderDataSetIterator.hasNext() line: 387
AsyncDataSetIterator$AsyncPrefetchThread.run() line: 412
run()方法中,iterator开始是空的,while (iterator.hasNext() && shouldWork.get())语句中的iterator.hasNext()从文件中加载数据。最终在LineIterator.hasNext()的String line = bufferedReader.readLine()语句中读到数据。smth = iterator.next()这一句是将数据从迭代器中读到数据集中,其中在CSVParser的parseLine()方法中将一行数据按逗号分割为字符串数组。数据集smth放到queue中。queue是BlockingQueue(接口,具体类型是LinkedBlockingQueue)的队列,这是个线程安全的阻塞队列。使用put()和take(),会有阻塞效果。不是很了解BlockingQueue的同学可自行百度。
异步线程的run()方法将数据集放到queue中后,由MultiLayerNetwork中fit方法的if (!iter.hasNext() && iter.resetSupported())语句中iter.hasNext()取出。其具体位置是AsyncDataSetIterator中hasNext()方法的nextElement = buffer.take()。这里的buffer和上文提到的queue是一个变量,或者说是指向同一块内存的两个变量。在异步线程的run()方法使用put()方法将数据集放到queue。主线程中使用take()方法从buffer中获取数据。由于put()和take()自身有阻塞效果,使用它们可以不用显示的写线程等待,就可以实现生产者-消费者模式。
跟代码时发现smth = iterator.next()这一句也能调用到LineIterator.hasNext()的String line = bufferedReader.readLine()。这个应该是重复了。前面已经调用过,缓存起来,后面就不用在调用了。
一条数据记录中,区分特征和标签的依据是MLPClassifierLinear的DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2)。其中的参数0表示标签是第0列,2表示标签有两种可能值。
在RecordReaderMultiDataSetIterator类中initializeUnderlying()方法的语句builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels)语句设置标签相关参数。这个OneHot很重要。由于数据文件中标签只有一列,但是神经网络输出层有两个神经元。通过OneHot属性,程序可以确定另一个神经元的理想值。initializeUnderlying()方法的语句builder.addInput(READER_KEY, inputFrom, inputTo)设置特征的开始列和结束列。
在RecordReaderMultiDataSetIterator中nextMultiDataSet()方法的Pair<INDArray[], INDArray[]> features = convertFeaturesOrLabels()和Pair<INDArray[], INDArray[]> labels = convertFeaturesOrLabels()语句分别从数据记录中读出特征和标签(输入和理想输出)。它们回调用convertWritables()方法构造特征和标签数组,并向里面填充数据。在构造标签数组时,convertWritables()根据OneHot为true和标签有两种可能值确定标签数组的具体结构。从这里可以看出,标签的可能值数量和输出层的神经元个数是要相等的。这个是分类问题的基本要求,本例属于分类问题。
现在数据集已经在主线程中获取到了,接下来就是网络定型阶段。