概述
本教程介绍梯度累积的训练方式,目的是为了解决由于内存不足导致某些大型网络无法训练大Batch_size的问题。
传统的训练方式是每次计算得到loss和梯度后,直接用所得梯度对参数进行更新。
与传统的训练方式不同,梯度累积引入Mini-batch的概念,首先对每个Mini-batch的数据计算loss和梯度,但不立即更新模型参数,而是先对所得梯度进行累加,然后在指定数量(N)个Mini-batch之后,用累积后的梯度更新网络参数。下次训练前清空过往累积梯度后重新累加,如此往复。最终目的是为了达到跟直接用N*Mini-batch数据训练几乎同样的效果。
本篇教程将分别介绍在单机模式和并行模式下如何实现梯度累积训练。
梯度下降算法大致分为三种
1. 批量梯度下降(Batch Gradient Descent,BGD)
2. 批量梯度下降(Batch Gradient Descent,BGD)
3. 小批量梯度下降(Mini-Batch Gradient Descent,MBGD)
单机模式
在单机模式下,主要通过将训练流程拆分为正向反向训练、参数更新和累积梯度清理三个部分实现梯度累积。这里以MNIST作为示范数据集,自定义简单模型实现梯度累积需要如下几个步骤。
首先要导入需要的文库:
代码如下:
import argparse