from mxnet import autograd, nd
方法
x = nd.arange(4).reshape((4, 1))
方法 |
结果 |
x.attach_grad() |
申请求梯度所需内存 |
autograd.record() |
用于上下文管理器,管理内容是待求梯度 |
y.backward() |
自动求梯度 |
autograd.is_training() |
查看是否处于训练模式 |
这里求梯度的方便之处式,即使函数中又python的控制语句存在,依然可以求梯度。比如下面的代码:
from mxnet import autograd, nd
def f(a):
while a.norm