paddle有个参数stop_gradient。
-
结论
- 对于op里面的参数,使用stop_gradient只影响这个参数本身,即只有这个参数不更新。
- 对于op的输出,使用stop_gradient,则该点之前的所有层均不再更新。
-
解释基本想法
- 设简单的op
out = op(x)
,一般op里面的参数不会依赖于输入x,out对x的梯度不会用到out对参数的梯度。所以参数设置stop_gradient,并不影响梯度传播。 - 似乎代码实际实现时候,先计算out对x,out对所有参数的梯度。即使存在诡异的情况,out对x的梯度依赖于out对参数的梯度,也没问题,因为是先算出所有梯度,所以仍然可以算出来out对x的梯度。前面层的梯度只跟out对x的梯度有关系(已经算出来了,而且是一个数值),所以参数设置stop_gradient,并不影响梯度传播。
- 对于op的输出,使用stop_gradient。由于这个节点已经没有梯度了,即loss对out没有梯度了,所以再往之前算梯度也就没有意义了。
- 设简单的op