MXNet中x.grad源码追溯

Python测试代码如https://zh.gluon.ai/chapter_prerequisite/autograd.html

本文追溯x.grad这一行代码的调用

grad调用的是函数MXNDArrayGetGrad,/usr/local/lib/python3.7/dist-packages/mxnet-1.5.0-py3.7.egg/mxnet/ndarray/ndarray.py

MXNDArrayGetGrad的源码依旧是在文件src/c_api/c_api.cc中,

NDArray ret = arr->grad();

ret就是获取到的梯度

这里grad的源码文件为src/ndarray/ndarray.cc,

Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);

return info.out_grads[0];

这里Imperative::AGInfo::Get的源码文件为 include/mxnet/imperative.h

return dmlc::get<AGInfo>(node->info);

这里get的源码文件为3rdparty/dmlc-core/include/dmlc/any.h

return *any::TypeInfo<T>::get_ptr(&(src.data_));

这个get_ptr调用的是同文件中的如下代码:

template<typename T>
class any::TypeOnHeap {
 public:
  inline static T* get_ptr(any::Data* data) {
    return static_cast<T*>(data->pheap);
  }

回到上面的代码,那个entry_是NDArrary类的一个对象:

  /*! \brief node entry for autograd */
  nnvm::NodeEntry entry_;

NodeEntry 源码文件为include/nnvm/node.h,

#大体来讲,梯度就是arr->entry_.node->info.data_.pheap;

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值