0.简介
经过前面两篇的铺垫,这里可以请出关键了,根据写好的函数表达式,计算其导数的表达式。
1.导数表达式计算
直接看这个问题,其实比较难解决,因为导数表达式直接来算还是需要不少处理过程的,但是前面我通过计算图能求得导数的数值结果,那么,我将计算导数数值结果函数中计算数值的步骤全部都换成对应的字符串输出就可以了,这样就能看见表达式结果了,有了前面的计算图,这里就不是难题。
2.实现例举
class AddNode : public OpNode
{
public:
AddNode(shared_ptr<OpNode> _a, shared_ptr<OpNode> _b) :a(_a), b(_b) {}
shared_ptr<OpNode> a,b;
float forward() { return a->forward() + b->forward(); }
float back() { return a->back() + b->back(); }
void forwardStr()
{
cout << "(";
a->forwardStr();
cout << ")";
cout << "+";
cout << "(";
b->forwardStr();
cout << ")";
}
void backStr()
{
cout << "(";
a->backStr();
cout << ")";
cout << "+";
cout << "(";
b->backStr();
cout << ")";
}
};
上面代码中添加了backStr和forwardStr函数,这两个函数就是仿照back和forward的流程,将对应式子的字符串输出。
总体代码如下。
#include<iostream>
#include<memory>
#include<string>
using namespace std;
class OpNode
{
public:
virtual float forward() { return 0; }
virtual float back() { return 0; }
virtual void forwardStr() { }
virtual void backStr() { }
};
using Op = shared_ptr<OpNode>;
class Variable : public OpNode
{
public:
Variable() {}
Variable(float n):v(n) {}
float v;
static shared_ptr<Variable> make_var(float n) { return make_shared<Variable>(Variable(n)); }
float forward() { return v; }
static void set_var(shared_ptr<Variable>& a, float b){a->v = b;}
float back() { return 1.0; }
void forwardStr() { cout<<"x"; }
void backStr() { cout << "1"; }
};
using Var = shared_ptr<Variable>;
class AddNode : public OpNode
{
public:
AddNode(shared_ptr<OpNode> _a, shared_ptr<OpNode> _b) :a(_a), b(_b) {}
shared_ptr<OpNode> a,b;
float forward() { return a->forward() + b->forward(); }
float back() { return a->back() + b->back(); }
void forwardStr()
{
cout << "(";
a->forwardStr();
cout << ")";
cout << "+";
cout << "(";
b->forwardStr();
cout << ")";
}
void backStr()
{
cout << "(";
a->backStr();
cout << ")";
cout << "+";
cout << "(";
b->backStr();
cout << ")";
}
};
class MulNode : public OpNode
{
public:
MulNode(shared_ptr<OpNode> _a, shared_ptr<OpNode> _b) :a(_a), b(_b) {}
shared_ptr<OpNode> a, b;
float forward() { return a->forward() * b->forward(); }
float back() { return a->back() * b->forward() + a->forward() * b->back(); }
void forwardStr()
{
cout << "(";
a->forwardStr();
cout << ")";
cout << "*";
cout << "(";
b->forwardStr();
cout << ")";
}
void backStr()
{
cout << "(";
a->backStr();
cout << ")";
cout << "*";
cout << "(";
b->forwardStr();
cout << ")";
cout << "+";
cout << "(";
a->forwardStr();
cout << ")";
cout << "*";
cout << "(";
b->backStr();
cout << ")";
}
};
class Constant : public OpNode
{
public:
float v;
Constant() {}
Constant(float n) :v(n) {}
float forward() { return v; }
float back() { return 0; }
void forwardStr()
{
cout << to_string(v);
}
void backStr()
{
cout <<"0";
}
static shared_ptr<Constant> make_var(float n) { return make_shared<Constant>(Constant(n)); }
};
shared_ptr<AddNode> operator+(shared_ptr<OpNode> a, shared_ptr<OpNode> b)
{
return make_shared<AddNode>(AddNode(a,b));
}
shared_ptr<MulNode> operator*(shared_ptr<OpNode> a, shared_ptr<OpNode> b)
{
return make_shared<MulNode>(MulNode(a, b));
}
shared_ptr<MulNode> operator*(float a, shared_ptr<OpNode> b)
{
return make_shared<MulNode>(MulNode(Constant::make_var(a), b));
}
shared_ptr<MulNode> operator*(shared_ptr<OpNode> a, float b)
{
return make_shared<MulNode>(MulNode(a, Constant::make_var(b)));
}
int main()
{
//声明变量
Var x = Variable::make_var(0.0);
//构建表达式
Op y = x * x + 2*x;
y->forwardStr();
cout << endl;
y->backStr();
return 0;
}
实际的执行结果如下。