DL4J源码阅读(二):数据加载

数据加载

    本例中的数据加载是用的异步方式,所以看主线程代码,根本找不到加载数据的地方。开启异步加载数据线程是在model.fit(trainIter)这个方法中。首先判断数据集迭代器是否支持异步加载if (iterator.asyncSupported()) 。由于本例中的是RecordReaderDataSetIterator,其asyncSupported()方法是返回true的,支持异步加载。创建AsyncDataSetIterator对象,在AsyncDataSetIterator的构造方法中,创建了AsyncPrefetchThread线程,并启动。

    AsyncPrefetchThread线程启动后,操作系统会通过线程调度运行run()方法。其读取linear_data_train.csv文件的线程堆栈如下:

Daemon Thread [ADSI prefetch thread] (Suspended)

LineIterator.hasNext() line: 95

CSVRecordReader(LineRecordReader).hasNext() line: 105

CSVRecordReader.hasNext() line: 140

RecordReaderDataSetIterator.hasNext() line: 387

AsyncDataSetIterator$AsyncPrefetchThread.run() line: 412

    run()方法中,iterator开始是空的,while (iterator.hasNext() && shouldWork.get())语句中的iterator.hasNext()从文件中加载数据。最终在LineIterator.hasNext()String line = bufferedReader.readLine()语句中读到数据。smth = iterator.next()这一句是将数据从迭代器中读到数据集中,其中在CSVParserparseLine()方法中将一行数据按逗号分割为字符串数组。数据集smth放到queue中。queueBlockingQueue(接口,具体类型是LinkedBlockingQueue)的队列,这是个线程安全的阻塞队列。使用put()take(),会有阻塞效果。不是很了解BlockingQueue的同学可自行百度。

    异步线程的run()方法将数据集放到queue中后,由MultiLayerNetworkfit方法的if (!iter.hasNext() && iter.resetSupported())语句中iter.hasNext()取出。其具体位置是AsyncDataSetIteratorhasNext()方法的nextElement = buffer.take()。这里的buffer和上文提到的queue是一个变量,或者说是指向同一块内存的两个变量。在异步线程的run()方法使用put()方法将数据集放到queue。主线程中使用take()方法从buffer中获取数据。由于put()take()自身有阻塞效果,使用它们可以不用显示的写线程等待,就可以实现生产者-消费者模式。

    跟代码时发现smth = iterator.next()这一句也能调用到LineIterator.hasNext()String line = bufferedReader.readLine()。这个应该是重复了。前面已经调用过,缓存起来,后面就不用在调用了。

    一条数据记录中,区分特征和标签的依据是MLPClassifierLinearDataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2)。其中的参数0表示标签是第0列,2表示标签有两种可能值。

    RecordReaderMultiDataSetIterator类中initializeUnderlying()方法的语句builder.addOutputOneHot(READER_KEY, labelIndex, numPossibleLabels)语句设置标签相关参数。这个OneHot很重要。由于数据文件中标签只有一列,但是神经网络输出层有两个神经元。通过OneHot属性,程序可以确定另一个神经元的理想值。initializeUnderlying()方法的语句builder.addInput(READER_KEY, inputFrom, inputTo)设置特征的开始列和结束列。

    RecordReaderMultiDataSetIteratornextMultiDataSet()方法的Pair<INDArray[], INDArray[]> features = convertFeaturesOrLabels()Pair<INDArray[], INDArray[]> labels = convertFeaturesOrLabels()语句分别从数据记录中读出特征和标签(输入和理想输出)。它们回调用convertWritables()方法构造特征和标签数组,并向里面填充数据。在构造标签数组时,convertWritables()根据OneHottrue和标签有两种可能值确定标签数组的具体结构。从这里可以看出,标签的可能值数量和输出层的神经元个数是要相等的。这个是分类问题的基本要求,本例属于分类问题。

    现在数据集已经在主线程中获取到了,接下来就是网络定型阶段。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值