Tensorflow.js 大规模训练技巧:用 Generator 分批读数据
本文将会介绍使用 Tensorflow.js 进行大规模训练时候的一个技巧:分批读取训练数据。我们将会用:
- JS 的 Generator 函数
- Tensorflow.js 的 tf.data.generator 和 model.fitDataset 两个 API
来实现分批训练。
前置知识
在阅读本文前,请确保你拥有以下知识,否则读起来可能比较吃力:
- 机器学习/深度学习基础,推荐阅读:机器学习速成教程。
- 基于 Tensorflow.js 的小规模训练经验,推荐阅读:Tensorflow.js 官方教程。
背景问题:线性增长的内存占用
对于深度学习来说,训练集越大,模型学习的效果越好。所以,在真实的工作中,我们的训练集往往会上 G,这时候,如果把所有训练集都读到内存中,那么还没等到训练,计算机的内存就吃不消,被爆掉了。
比如,你有这样一段代码,将所有训练集一次性读取到内存中:
// trainSet 是一个数组,里面包含了所有的训练集的硬盘路径信息和分类信息。
// getTensor 函数用来读取训练集到内存,并生成 tensor
const {
xs, ys } = getTensor(trainSet);
await model.fit(xs, ys, {
epochs: 20 });