一、背景信息
用户在训练模型时，常常会遇到loss不收敛或其他精度问题。这类问题往往是由于梯度对参数的调整不合适或梯度更新异常造成的。针对由梯度引起的精度问题，本文介绍一种打印模型梯度的方法，便于观察每一步训练中梯度的数值。
二、低阶API打印模型梯度示例代码段
...
...
class TrainOneStepCellPrintGrad(TrainOneStepCell):
    """TrainOneStepCell variant that prints the gradients of every training step.

    Performs one forward/backward/update step like ``TrainOneStepCell`` and
    additionally prints the gradients (after ``grad_reducer``) with
    ``ops.Print`` before they are passed to the optimizer.

    Args:
        network: Network to train. ``construct`` reads ``network.output``, so
            the wrapped network must expose that attribute (assumed to be the
            ``BuildTrainNetwork`` wrapper used by the training script — TODO
            confirm).
        optimizer: Optimizer that applies the gradients.
        sens: Value used to fill the sensitivity tensor fed to the gradient
            function. Defaults to 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        # Forward the caller's ``sens``. The previous code hard-coded
        # ``sens=1.0`` here, silently discarding any non-default value.
        super().__init__(network, optimizer, sens=sens)
        self.print = ops.Print()

    def construct(self, *inputs):
        """Run one training step and print the gradients.

        Returns:
            tuple: ``(loss, output)``, where ``output`` is ``self.network.output``.
        """
        loss = self.network(*inputs)
        output = self.network.output
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        # self.grad is provided by the TrainOneStepCell base class; it yields
        # the gradients with respect to self.weights.
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        # grad_reducer also comes from the base class (presumably aggregates
        # gradients across devices in data-parallel runs — confirm).
        grads = self.grad_reducer(grads)
        # Printed after reduction, i.e. the gradients the optimizer receives.
        self.print("grad:", grads)
        # Make the returned loss depend on the optimizer update so the update
        # is part of the executed graph.
        loss = F.depend(loss, self.optimizer(grads))
        return loss, output
...
...
# Build the training cell: BuildTrainNetwork combines `net` with `loss_function`,
# and TrainOneStepCellPrintGrad wraps it so every step prints its gradients.
model_constructed = TrainOneStepCellPrintGrad(
    BuildTrainNetwork(net, loss_function, TRAIN_BATCH_SIZE, CLASS_NUM),
    opt,
)
# Start training with the gradient-printing cell.
train_net(
    model_constructed,
    net,
    loss_function,
    CHECKPOINT_MAX,
    EPOCH_MAX,
    TRAIN_PATH,
    VAL_PATH,
    TRAIN_BATCH_SIZE,
    VAL_BATCH_SIZE,
    REPEAT_SIZE,
)
关键代码解析:
grads = self.grad(self.network, self.weights)(*inputs, sens)
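TrainOneStepCellPrintGrad 继承自 TrainOneStepCell，而 TrainOneStepCell 在初始化时已经定义了 self.grad（一个以 get_by_list=True、sens_param=True 构造的 mindspore.ops.GradOperation 实例），因此这里可以直接调用它：self.grad(self.network, self.weights) 返回对 self.weights 中各参数求梯度的函数，再以网络输入 *inputs 和灵敏度 sens 调用该函数，即可得到各参数的梯度。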
三、高阶API打印模型梯度示例代码段
使用高阶API（Model）训练时，可以通过改写 Model 所调用的 build_train_network 函数，让它用自定义的 TrainOneStepCellV2 包装网络，从而实现打印模型梯度的功能。
...
...
class TrainOneStepCellV2(TrainOneStepCell):
    """Training cell that prints the gradients of every step.

    Used by the rewritten ``build_train_network``. It performs one
    forward/backward/update step like ``TrainOneStepCell`` and prints the
    gradients right after they are computed, before ``grad_reducer``.

    Args:
        network: Network to train. ``construct`` reads ``network.output``, so
            the wrapped network must expose that attribute (obtained from
            ``BuildTrainNetwork``).
        optimizer: Optimizer that applies the gradients.
        sens: Value used to fill the sensitivity tensor. ``build_train_network``
            passes its ``loss_scale`` here. Defaults to 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        # Forward the caller's ``sens``. The previous code hard-coded
        # ``sens=1.0``, so the loss_scale passed in by build_train_network was
        # silently ignored.
        super().__init__(network, optimizer, sens=sens)
        # Optional: enable gradient clipping by creating a clipper here, e.g.
        # self.clip_gradients = ClipGradients()
        self.print = P.Print()

    def construct(self, *inputs):
        """Run one training step and print the gradients.

        Returns:
            tuple: ``(loss, output)``, where ``output`` is ``self.network.output``.
        """
        weights = self.weights
        loss = self.network(*inputs)
        # The ``output`` attribute is provided by the BuildTrainNetwork wrapper.
        output = self.network.output
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        # Gradients of the loss with respect to the trainable parameters.
        grads = self.grad(self.network, weights)(*inputs, sens)
        # Printed before grad_reducer (presumably per-device gradients in
        # data-parallel runs — confirm).
        self.print('grad:', grads)
        # Optional clipping hook (requires self.clip_gradients in __init__):
        # grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        grads = self.grad_reducer(grads)
        # Update the parameters; F.depend ties the update to the returned loss.
        loss = F.depend(loss, self.optimizer(grads))
        return loss, output
...
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """Rewritten ``build_train_network`` that wraps the network in TrainOneStepCellV2.

    Only the final wrapping step differs from the function being replaced;
    the setup code before it is elided in this excerpt.

    Returns:
        The TrainOneStepCellV2 cell returned by ``set_train()``.
    """
    ...
    # Rewritten part: use the gradient-printing cell as the training cell.
    # NOTE(review): ``loss_scale`` is assumed to be computed in the elided code
    # above (e.g. from the loss-scale manager, 1.0 otherwise) — confirm.
    # Passed by keyword to make clear it becomes the cell's ``sens``.
    network = TrainOneStepCellV2(network, optimizer, sens=loss_scale).set_train()
    return network


class Model:
    """Excerpt of ``Model`` rewritten so that it uses the build_train_network above."""
    ...

    def _build_train_network(self):
        """Build the training network by calling the rewritten build_train_network."""
        ...
        network = build_train_network(...)
        ...