背景信息
在混合精度中,使用float16类型来替代float32类型存储数据,从而达到减少内存和提高计算速度的效果。但是由于float16类型要比float32类型表示的范围小很多,所以当某些参数(比如说梯度)在训练过程中变得很小时,就会发生数据下溢的情况,进而影响网络精度。而loss scale正是为了解决float16类型数据下溢问题的,loss scale的主要思想是在计算loss时,将loss扩大一定的倍数,由于链式法则的存在,梯度也会相应扩大,然后在优化器更新权重时再缩小相应的倍数,从而避免了数据下溢的情况又不影响计算结果。
在MindSpore中,loss scale的使用方法又分动态loss scale和静态loss scale两种,二者具体区别详见静态LossScale和动态LossScale的区别。
在低阶API使用动态LossScale功能时,通常结合TrainOneStepWithLossScaleCell来实现。
1、示例代码段
**步骤:**在使用动态loss scale功能前,首先要定义一张网络、损失函数、优化器,然后将损失函数融入到神经网络中,再使用动态loss scale功能
from src.new import TrainOneStepWithLossScaleCell
# Define network
net = resnet()
# Define the loss_function
loss_function = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
opt = nn.SGD(net.trainable_params(), LR_ORI, MOMENTUM_ORI, WEIGHT_DECAY)
# Bind loss_function to net
model_constructed = BuildTrainNetwork(net, loss_function, TRAIN_BATCH_SIZE, CLASS_NUM)
# Define the Dynamic Loss scale update cell
loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
# Define Network training with loss scaling
model_constructed = TrainOneStepWithLossScaleCell(model_constructed, opt, scale_sense=loss_scale_manager)
# Train
train_net(model_constructed, net, loss_function, EPOCH_MAX, TRAIN_PATH, VAL_PATH, TRAIN_BATCH_SIZE, VAL_BATCH_SIZE, REPEAT_SIZE)
2、代码解析
loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
**功能:**定义一个动态loss scale的控制器。指定loss scale的初始值为loss_scale_value。当训练发生上溢出时,将执行loss_scale_value=loss_scale_value/scale_factor。如果连续scale_window个step没有发生上溢出,则执行loss_scale_value=loss_scale_value*scale_factor。
接口参数详解:
loss_scale_value :loss scale的初始值,float数据类型;
scale_factor:loss_scale的调整系数,int数据类型;
scale_window:如果连续scale_window个step没有发生溢出,则执行loss_scale_value=loss_scale_value*scale_factor,int数据类型 ;
model_constructed = TrainOneStepWithLossScaleCell(network=model_constructed, optimizer=opt, scale_sense=loss_scale_manager)
功能:
此接口封装了整网的训练步骤,它接受网络、优化器和loss scale控制器作为入参。
接口参数详解:
network:训练网络;
optimizer:优化器;
scale_sense:loss scale控制器;