0.简介
有了前一篇的计算图铺垫,这次我可以计算出函数导数的结果。
1.计算导数
一般来说,计算导数的时候,我们利用导数公式可以计算,根据前一篇最开始说的,如果程序中的方程或者函数直接计算,那么这个函数的导数也需要再重写一个才可以,但是利用计算图则可以不必针对某个方程专门计算的导数,而是利用计算图的特性来计算。例如,设,,然后,计算,,对于计算图来说,计算到图的叶子节点就停止计算。实际效果如下图。
所以,利用最基本的导数运算规则,对每个结点做求导操作,最后得到的值就是导数,图会自动的根据构建方式来计算x=n时候的导数结果,导数实际对应的就是曲线一点的斜线斜率。
2.添加代码
根据上一篇继续添加计算导数函数,在其中添加back函数来求导数数值结果。
class OpNode
{
public:
virtual float forward() { return 0; }
virtual float back() { return 0; }
};
using Op = shared_ptr<OpNode>;
class Variable : public OpNode
{
public:
Variable() {}
Variable(float n):v(n) {}
float v;
static shared_ptr<Variable> make_var(float n) { return make_shared<Variable>(Variable(n)); }
float forward() { return v; }
static void set_var(shared_ptr<Variable>& a, float b){a->v = b;}
float back() { return 1.0; }
};
using Var = shared_ptr<Variable>;
class AddNode : public OpNode
{
public:
AddNode(shared_ptr<OpNode> _a, shared_ptr<OpNode> _b) :a(_a), b(_b) {}
shared_ptr<OpNode> a,b;
float forward() { return a->forward() + b->forward(); }
float back() { return a->back() + b->back(); }
};
class MulNode : public OpNode
{
public:
MulNode(shared_ptr<OpNode> _a, shared_ptr<OpNode> _b) :a(_a), b(_b) {}
shared_ptr<OpNode> a, b;
float forward() { return a->forward() * b->forward(); }
float back() { return a->back() * b->forward() + a->forward() * b->back(); }
};
class Constant : public OpNode
{
public:
float v;
Constant() {}
Constant(float n) :v(n) {}
float forward() { return v; }
float back() { return 0; }
static shared_ptr<Constant> make_var(float n) { return make_shared<Constant>(Constant(n)); }
};
你可能奇怪,为何计算个某点的导数要这么麻烦,完全可以利用,确实如此,但是后面我打算计算更神奇的东西,那么简单写法的导数计算就不能满足要求了。