分批读取数据_Tensorflow.js 大规模训练技巧:用 Generator 分批读数据

Tensorflow.js 大规模训练技巧:用 Generator 分批读数据

本文将会介绍使用 Tensorflow.js 进行大规模训练时候的一个技巧:分批读取训练数据。我们将会用:

  • JS 的 Generator 函数
  • Tensorflow.js 的 tf.data.generator 和 model.fitDataset 两个 API

来实现分批训练。

前置知识

在阅读本文前,请确保你拥有以下知识,否则读起来可能比较吃力:

  • 机器学习/深度学习基础,推荐阅读:机器学习速成教程。
  • 基于 Tensorflow.js 的小规模训练经验,推荐阅读:Tensorflow.js 官方教程。

背景问题:线性增长的内存占用

对于深度学习来说,训练集越大,模型学习的效果越好。所以,在真实的工作中,我们的训练集往往会上 G,这时候,如果把所有训练集都读到内存中,那么还没等到训练,计算机的内存就吃不消,被爆掉了。

比如,你有这样一段代码,将所有训练集一次性读取到内存中:

// trainSet 是一个数组,里面包含了所有的训练集的硬盘路径信息和分类信息。
// getTensor 函数用来读取训练集到内存,并生成 tensor
const {
     xs, ys } = getTensor(trainSet);

await model.fit(xs, ys, {
     epochs: 20 });
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值