Cell训练状态转换

神经网络中的部分Tensor操作在训练和推理时的表现并不相同,如nn.Dropout在训练时进行随机丢弃,但在推理时则不丢弃,nn.BatchNorm在训练时需要更新mean和var两个变量,在推理时则固定其值不变。因此我们可以通过Cell.set_train接口来设置神经网络的状态。

set_train(True)时,神经网络状态为train, set_train接口默认值为True:

神经网络中的部分Tensor操作在训练和推理时的表现并不相同,如nn.Dropout在训练时进行随机丢弃,但在推理时则不丢弃,nn.BatchNorm在训练时需要更新mean和var两个变量,在推理时则固定其值不变。因此我们可以通过Cell.set_train接口来设置神经网络的状态。

set_train(True)时,神经网络状态为train, set_train接口默认值为True:

set_train(False)时,神经网络状态为predict:

自定义神经网络层

通常情况下,MindSpore提供的神经网络层接口和function函数接口能够满足模型构造需求,但由于AI领域不断推陈出新,因此有可能遇到新网络结构没有内置模块的情况。此时我们可以根据需要,通过MindSpore提供的function接口、Primitive算子自定义神经网络层,并可以使用Cell.bprop方法自定义反向。下面分别详述三种自定义方法。

使用function接口构造神经网络层

MindSpore提供大量基础的function接口,可以使用其构造复杂的Tensor操作,封装为神经网络层。下面以Threshold为例,其公式如下:

昇思MindSpore学习入门-CELL与参数二_自定义

可以看到Threshold判断Tensor的值是否大于threshold值,保留判断结果为True的值,替换判断结果为False的值。因此,对应实现如下:

class Threshold(nn.Cell):

    def __init__(self, threshold, value):

        super().__init__()

        self.threshold = threshold

        self.value = value

 

    def construct(self, inputs):

        cond = ops.gt(inputs, self.threshold)

        value = ops.fill(inputs.dtype, inputs.shape, self.value)

        return ops.select(cond, inputs, value)

 

这里分别使用了ops.gt、ops.fill、ops.select来实现判断和替换。下面执行自定义的Threshold层:

m = Threshold(0.1, 20)

inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)

m(inputs)

可以看到inputs[0] = threshold, 因此被替换为20。

自定义Cell反向

在特殊场景下,我们不但需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算,此时我们可以通过Cell.bprop接口对其反向进行定义。在全新的神经网络结构设计、反向传播速度优化等场景下会用到该功能。下面我们以Dropout2d为例,介绍如何自定义Cell反向:

class Dropout2d(nn.Cell):

    def __init__(self, keep_prob):

        super().__init__()

        self.keep_prob = keep_prob

        self.dropout2d = ops.Dropout2D(keep_prob)

 

    def construct(self, x):

        return self.dropout2d(x)

 

    def bprop(self, x, out, dout):

        _, mask = out

        dy, _ = dout

        if self.keep_prob != 0:

            dy = dy * (1 / self.keep_prob)

        dy = mask.astype(mindspore.float32) * dy

        return (dy.astype(x.dtype), )

 

dropout_2d = Dropout2d(0.8)

dropout_2d.bprop_debug = True

 

bprop方法分别有三个入参:

  • x: 正向输入,当正向输入为多个时,需同样数量的入参。
  • out: 正向输出。
  • dout: 反向传播时,当前Cell执行之前的反向结果。

一般我们需要根据正向输出和前层反向结果配合,根据反向求导公式计算反向结果,并将其返回。Dropout2d的反向计算需要根据正向输出的mask矩阵对前层反向结果进行mask,然后根据keep_prob进行缩放。最终可得到正确的计算结果。

Hook功能

调试深度学习网络是每一个深度学习领域的从业者需要面对且投入精力较大的工作。由于深度学习网络隐藏了中间层算子的输入、输出数据以及反向梯度,只提供网络输入数据(特征量、权重)的梯度,导致无法准确地感知中间层算子的数据变化,从而降低了调试效率。为了方便用户准确、快速地对深度学习网络进行调试,MindSpore在动态图模式下设计了Hook功能,使用Hook功能可以捕获中间层算子的输入、输出数据以及反向梯度

目前,动态图模式下提供了四种形式的Hook功能,分别是:HookBackward算子和在Cell对象上进行注册的register_forward_pre_hook、register_forward_hook、register_backward_hook功能。

HookBackward算子

HookBackward将Hook功能以算子的形式实现。用户初始化一个HookBackward算子,将其安插到深度学习网络中需要捕获梯度的位置。在网络正向执行时,HookBackward算子将输入数据不做任何修改后原样输出;在网络反向传播梯度时,在HookBackward上注册的Hook函数将会捕获反向传播至此的梯度。用户可以在Hook函数中自定义对梯度的操作,比如打印梯度,或者返回新的梯度。

示例代码:

昇思MindSpore学习入门-CELL与参数二_数据_02

 

Cell对象的register_forward_pre_hook功能

用户可以在Cell对象上使用register_forward_pre_hook函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。该功能在静态图模式下和在使用@jit修饰的函数内不起作用。register_forward_pre_hook函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle对象。用户可以通过调用handle对象的remove()函数来删除与之对应的Hook函数。每一次调用register_forward_pre_hook函数,都会返回一个不同的handle对象。Hook函数应该按照以下的方式进行定义。

def forward_pre_hook_fn(cell, inputs):

    print("forward inputs: ", inputs)

这里的cell是Cell对象,inputs是正向传入到Cell对象的数据。因此,用户可以使用register_forward_pre_hook函数来捕获网络中某一个Cell对象的正向输入数据。用户可以在Hook函数中自定义对输入数据的操作,比如查看、打印数据,或者返回新的输入数据给当前的Cell对象。如果在Hook函数中对Cell对象的原始输入数据进行计算操作后,再作为新的输入数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。

示例代码:

昇思MindSpore学习入门-CELL与参数二_数据_03

用户如果在Hook函数中直接返回新创建的数据,而不是返回由原始输入数据经过计算后得到的数据,那么梯度的反向传播将会在该Cell对象上截止。

示例代码:

 

为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_forward_pre_hook 函数和 handle 对象的 remove() 函数。在动态图模式下,如果在Cell对象的 construct 函数中调用 register_forward_pre_hook 函数,那么Cell对象每次运行都将新注册一个Hook函数。

Cell对象的register_forward_hook功能

用户可以在Cell对象上使用register_forward_hook函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用@jit修饰的函数内不起作用。register_forward_hook函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle对象。用户可以通过调用handle对象的remove()函数来删除与之对应的Hook函数。每一次调用register_forward_hook函数,都会返回一个不同的handle对象。Hook函数应该按照以下的方式进行定义。

昇思MindSpore学习入门-CELL与参数二_数据_04

示例代码:

def forward_hook_fn(cell, inputs, outputs):

    print("forward inputs: ", inputs)

    print("forward outputs: ", outputs)

 

这里的cell是Cell对象,inputs是正向传入到Cell对象的数据,outputs是Cell对象的正向输出数据。因此,用户可以使用register_forward_hook函数来捕获网络中某一个Cell对象的正向输入数据和输出数据。用户可以在Hook函数中自定义对输入、输出数据的操作,比如查看、打印数据,或者返回新的输出数据。如果在Hook函数中对Cell对象的原始输出数据进行计算操作后,再作为新的输出数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。

示例代码:

昇思MindSpore学习入门-CELL与参数二_数据_05

用户如果在Hook函数中直接返回新创建的数据,而不是将原始的输出数据经过计算后,将得到的新输出数据返回,那么梯度的反向传播将会在该Cell对象上截止。该现象可以参考register_forward_pre_hook函数的用例说明。 为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的construct函数中调用register_forward_hook函数和handle对象的remove()函数。在动态图模式下,如果在Cell对象的construct函数中调用register_forward_hook函数,那么Cell对象每次运行都将新注册一个Hook函数。

Cell对象的register_backward_hook功能

用户可以在Cell对象上使用register_backward_hook函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用@jit修饰的函数内不起作用。register_backward_hook函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle对象。用户可以通过调用handle对象的remove()函数来删除与之对应的Hook函数。每一次调用register_backward_hook函数,都会返回一个不同的handle对象。

与HookBackward算子所使用的自定义Hook函数有所不同,register_backward_hook使用的Hook函数的入参中,包含了表示Cell对象名称与id信息的cell_id、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。

示例代码:

def backward_hook_function(cell_id, grad_input, grad_output):

    print(grad_input)

    print(grad_output)

这里的cell_id是Cell对象的名称以及ID信息,grad_input是网络反向传播时,传入到Cell对象的梯度,它对应于正向过程中下一个算子的反向输出梯度;grad_output是Cell对象反向输出的梯度。因此,用户可以使用register_backward_hook函数来捕获网络中某一个Cell对象的反向传入和反向输出梯度。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输出梯度。如果需要在Hook函数中返回新的输出梯度时,返回值必须是tuple的形式。

示例代码:

昇思MindSpore学习入门-CELL与参数二_反向传播_06

 

当 register_backward_hook 函数和 register_forward_pre_hook 函数、 register_forward_hook 函数同时作用于同一Cell对象时,如果 register_forward_pre_hook 和 register_forward_hook 函数中有添加其他算子进行数据处理,这些新增算子会在Cell对象执行前或者执行后参与数据的正向计算,但是这些新增算子的反向梯度不在 register_backward_hook 函数的捕获范围内。 register_backward_hook 中注册的Hook函数仅捕获原始Cell对象的输入、输出梯度。

示例代码:

昇思MindSpore学习入门-CELL与参数二_数据_07

这里的 grad_input 是梯度反向传播时传入self.relu的梯度,而不是传入 forward_hook_fn 函数中,新增的 Add 算子的梯度。这里的 grad_output 是梯度反向传播时 self.relu 反向输出的梯度,而不是 forward_pre_hook_fn 函数中新增 Add 算子的反向输出梯度。 register_forward_pre_hook 函数和 register_forward_hook 函数是在Cell对象执行前后起作用,不会影响Cell对象上反向Hook函数的梯度捕获范围。

为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct 函数中调用 register_backward_hook 函数和 handle 对象的 remove() 函数。在PyNative模式下,如果在Cell对象的 construct 函数中调用 register_backward_hook 函数,那么Cell对象每次运行都将新注册一个Hook函数。