背景信息
用户在使用深度学习网络结构执行开发任务时,难免会遇到一些代码上看上去很正常,而且代价函数也在不断减小,但是因为不知道的bug存在,使得我们得到的神经网络并不是最优解。为详细排查所遇问题,故在此提供2种不同接口用以实现梯度不回传以及梯度回传后不更新权重功能。
一、使用stop_gradient接口实现
1、示例代码
import numpy as np
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.ops
import operations as Pimport mindspore.nn as nn
import mindsporefrom mindspore.ops import stop_gradient
#设置训练环境
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class PrintDemo(nn.Cell):
def __init__(self):
super(PrintDemo, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=4, stride=1 ,has_bias=False, weight_init='normal', pad_mode='valid')
self.conv2 = nn.Conv2d(in_channels=6, out_channels=2, kernel_size=2, pad_mode="valid")
self.conv3 = nn.Conv2d(in_channels=2, out_channels=6, kernel_size=2, pad_mode="valid")
self.print = P.Print()
#打印出特定层权重输出结果
def construct(self, input_data):
x = self.conv1(input_data)
x = stop_gradient(x)
self.print("self.conv1.weight:", self.conv1.weight)
x = self.conv2(x)
x = self.conv3(x)
return x
def test():
input_data = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
net = PrintDemo()
net(input_data)
return net(input_data)
test()
参数注解:
stop_gradient():在反向传播时禁止网络梯度更新。
二、使用requires_grad接口实现
1、示例代码
class PrintDemo(nn.Cell):
#打印出特定层权重输出结果
def construct(self, input_data):
...
# x = stop_gradient(x)
...
def test():
...
for param in net.trainable_params():
if 'conv1' in param.name:
param.requires_grad = False
else:
param.requires_grad = True
...
test()
参数注解:
requires_grad:bool类型,当值为True时表面该参数需要更新,反之则不需更新。
三、实验截图
梯度更新的中心思想沿着loss函数梯度的方向更新权重以让loss函数的值最小化或accuracy最大化,在示例代码中使用requires_grad与stop_gradient方法实现禁止conv1层梯度更新。实验中conv1层中开始与结束时的权重变化如下图所示:
conv1中权重值的变化: