C++ 将数学表达式构建成图
flyfish
调试环境:Visual Studio 2017(VC2017)
构建成图之后,再计算表达式的值
原文是
How To Write Your Own Tensorflow in C++
如何使用C++编写自己的Tensorflow
本文将数学表达式构建成图,再计算表达式的值。部分代码摘自原文,为了便于理解,对代码做了修改。
例如将以下表达式构建成图
t1=1+2
t2=t1*3
t3=t2+exp(4)
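用下面的 BuildGraph 函数(由重载的 +、-、*、/ 运算符以及 exp()、poly() 调用)即可构建这样的图:每个运算对应一个节点,它的操作数是该节点的子节点。
主函数调用代码如下。它构建的表达式是 t4 = 10*11 + t1*t2,其中 t1 = 12,t2 = t1 + 2^3 = 20,因此 eval() 的结果为 350: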
#include "stdafx.h"
#include "GraphDef.h"
#include <iostream>
// Demo: builds the expression graph for
//   t1 = 12
//   t2 = t1 + 2^3
//   t4 = 10 * 11 + t1 * t2
// then evaluates it and prints the result (10*11 + 12*(12 + 8) = 350).
int main()
{
    ai::NodeDef t1 = 12;
    ai::NodeDef t2 = t1 + ai::poly(2, 3);
    ai::NodeDef t3(10);
    ai::NodeDef t4 = t3 * 11 + t1 * t2;
    ai::GraphDef g(t4);
    const double a = g.eval();
    // Previously the result was computed and discarded; show it.
    std::cout << "t4 = " << a << '\n';
    return 0;
}
NodeDef.h 头文件
#pragma once
#include <iostream>
#include <vector>
#include <memory>
#include <utility>
namespace ai
{
class NodeDef;

/// Operation performed by a graph node.
/// Enumerators take their values from declaration order (plus == 0 ... none == 6).
enum class OpDef : int {
    plus,        // binary operator+
    minus,       // binary operator-
    multiply,    // binary operator*
    divide,      // binary operator/
    exponent,    // exp(x) = e^x
    polynomial,  // poly(x, n) = x^n: x is the base, n the exponent ("x to the n-th power")
    none         // no operator: a leaf node
};

/// Number of operands that `op` expects.
int numOpArgs(OpDef op);
}
namespace std
{
// Lets ai::NodeDef be used as a key of unordered containers.
// Only the call operator is declared here; it is defined out of line.
template <>
struct hash<ai::NodeDef>
{
    std::size_t operator()(const ai::NodeDef& node) const;
};
}
namespace ai
{
class NodeDef
{
struct impl;
public:
NodeDef(std::shared_ptr<impl>);
NodeDef(double);
NodeDef(OpDef, const std::vector<NodeDef>&);
~NodeDef();
NodeDef(NodeDef&&) noexcept;
NodeDef& operator=(NodeDef&&) noexcept;
//浅拷贝
NodeDef(const NodeDef&);
NodeDef& operator=(const NodeDef&);
//深拷贝
NodeDef Clone();
double GetValue() const;
void SetValue(double);
OpDef GetOp() const;
void SetOp(OpDef);
std::vector<NodeDef>& GetChildren() const;
std::vector<NodeDef> GetParents() const;
bool operator==(const NodeDef& rhs) const;
friend struct std::hash<NodeDef>;
template <typename... V>
friend const NodeDef BuildGraph(OpDef, V&...);
private:
// PImpl idiom :
std::shared_ptr<impl> pimpl;
};
struct NodeDef::impl
{
public:
impl(double);
impl(OpDef, const std::vector<NodeDef>&);
double val;
OpDef op;
std::vector<NodeDef> children;
std::vector<std::weak_ptr<impl>> parents;
};
template <typename... V>
const NodeDef BuildGraph(OpDef op, V&... args){
std::vector<std::shared_ptr<NodeDef::impl> > vimpl = { args.pimpl... };
std::vector<NodeDef> v;
for(const std::shared_ptr<NodeDef::impl>& _impl : vimpl){
v.emplace_back(_impl);
}
NodeDef res(op, v);
for(const std::shared_ptr<NodeDef::impl>& _impl : vimpl){
_impl->parents.push_back(res.pimpl);
}
return res;
}
inline const NodeDef operator+(NodeDef lhs, NodeDef rhs){
return BuildGraph(OpDef::plus, lhs, rhs);
}
inline const NodeDef operator-(NodeDef lhs, NodeDef rhs){
return BuildGraph(OpDef::minus, lhs, rhs);
}
inline const NodeDef operator*(NodeDef lhs, NodeDef rhs){
return BuildGraph(OpDef::multiply, lhs, rhs);
}
inline const NodeDef operator/(NodeDef lhs, NodeDef rhs){
return BuildGraph(OpDef::divide, lhs, rhs);
}
inline const NodeDef exp(NodeDef v){
return BuildGraph(OpDef::exponent, v);
}
inline const NodeDef poly(NodeDef v, NodeDef power){
NodeDef p(power);
return BuildGraph(OpDef::polynomial, v, p);
}
}
NodeDef.cpp 实现文件
#include "stdafx.h"
#include "NodeDef.h"
#include <map>
namespace ai
{
int numOpArgs(OpDef op)
{
static const std::map<OpDef, int> op_args = {
{ OpDef::plus, 2 },
{ OpDef::minus, 2 },
{ OpDef::multiply, 2 },
{ OpDef::divide, 2 },
{ OpDef::exponent, 1 },
{ OpDef::polynomial, 1 },
{ OpDef::none, 0 },
};
return op_args.find(op)->second;
};
NodeDef::NodeDef(NodeDef&&) noexcept = default;
NodeDef& NodeDef::operator=(NodeDef&&) noexcept = default;
NodeDef::~NodeDef() = default;
NodeDef::NodeDef(const NodeDef&) = default;
NodeDef& NodeDef::operator=(const NodeDef&) = default;
NodeDef NodeDef::Clone()
{
return NodeDef(std::make_shared<impl>(*pimpl));
}
NodeDef::NodeDef(std::shared_ptr<impl> _pimpl) : pimpl(_pimpl) {};
NodeDef::NodeDef(double _val)
: pimpl(new impl(_val)) {}
NodeDef::NodeDef(OpDef _op, const std::vector<NodeDef>& _children)
: pimpl(new impl(_op, _children)) {}
/* Getters and Setters */
double NodeDef::GetValue() const { return pimpl->val; }
void NodeDef::SetValue(double _val) { pimpl->val = _val; }
OpDef NodeDef::GetOp() const { return pimpl->op; }
void NodeDef::SetOp(OpDef _op) { pimpl->op = _op; }
std::vector<NodeDef>& NodeDef::GetChildren() const { return pimpl->children; }
std::vector<NodeDef> NodeDef::GetParents() const {
std::vector<NodeDef> _parents;
for (std::weak_ptr<impl> parent : pimpl->parents) {
_parents.emplace_back(parent.lock());
}
return _parents;
}
bool NodeDef::operator==(const NodeDef& rhs) const
{
return pimpl.get() == rhs.pimpl.get();
}
NodeDef::impl::impl(double _val) :
val(_val),
op(OpDef::none) {}
NodeDef::impl::impl(OpDef _op, const std::vector<NodeDef>& _children)
: op(_op) {
for (const NodeDef& v : _children) {
children.emplace_back(v.pimpl);
}
}
}
namespace std
{
// Hashes a NodeDef by identity, consistent with NodeDef::operator== (which
// compares the shared state pointers).
size_t hash<ai::NodeDef>::operator()(const ai::NodeDef& v) const
{
    // The standard specifies hash<shared_ptr<T>>()(p) == hash<T*>()(p.get()),
    // so hashing the raw pointer yields exactly the same value.
    ai::NodeDef::impl* const raw = v.pimpl.get();
    return std::hash<ai::NodeDef::impl*>()(raw);
}
}
GraphDef.h 头文件
#pragma once
#include "NodeDef.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ai
{
/// Evaluates the expression graph below a root node.
class GraphDef
{
public:
    /// Wraps the graph rooted at the given node. The node handle is shared,
    /// not deep-copied.
    GraphDef(NodeDef);
    /// Returns a handle to the root node.
    NodeDef GetRoot() const;
    /// Recomputes the operator nodes reachable from the root, children before
    /// parents, and returns the root's value.
    double propagate();
    /// Recomputes values bottom-up (breadth-first) starting from `leaves` and
    /// following parent links, then returns the root's stored value.
    /// Presumably `leaves` must list every leaf of the graph exactly once for
    /// the root to be updated - TODO confirm with callers.
    double propagate(const std::vector<NodeDef>& leaves);
    /// Computes the value of the expression (delegates to propagate()).
    double eval();
private:
    NodeDef root;
};
}
GraphDef.cpp 实现文件
#include "stdafx.h"
#include "GraphDef.h"
#include <cmath>
#include <cstddef>
#include <exception>
#include <queue>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ai
{
namespace
{
// Applies `op` to the already-computed values of `operands`.
// Throws std::invalid_argument for OpDef::none (a leaf has nothing to compute)
// or an unknown operator, and std::out_of_range if an operand is missing.
double _eval(OpDef op, const std::vector<NodeDef>& operands)
{
    // at() instead of [] so a malformed node throws instead of reading past the end.
    switch (op) {
    case OpDef::plus:
        return operands.at(0).GetValue() + operands.at(1).GetValue();
    case OpDef::minus:
        return operands.at(0).GetValue() - operands.at(1).GetValue();
    case OpDef::multiply:
        return operands.at(0).GetValue() * operands.at(1).GetValue();
    case OpDef::divide:
        return operands.at(0).GetValue() / operands.at(1).GetValue();
    case OpDef::exponent:
        return std::exp(operands.at(0).GetValue());
    case OpDef::polynomial:
        // x^n: operand 0 is the base, operand 1 the exponent.
        return std::pow(operands.at(0).GetValue(), operands.at(1).GetValue());
    case OpDef::none:
        throw std::invalid_argument("_eval invalid argument");
    }
    // Previously returned 0.0 silently for out-of-range operator values.
    throw std::invalid_argument("_eval unknown operator");
}

// Evaluates `v` after its children (post-order). `done` records operator nodes
// already evaluated during this pass, so a sub-expression shared by several
// parents is computed once. Without it, e.g. repeating `x = x + x;` k times
// would make the walk visit 2^k paths.
void _rpropagate(NodeDef& v, std::unordered_set<NodeDef>& done)
{
    // Leaves already hold their value; a node seen before is already up to date.
    if (v.GetChildren().empty() || !done.insert(v).second)
        return;
    for (NodeDef& child : v.GetChildren()) {
        _rpropagate(child, done);
    }
    v.SetValue(_eval(v.GetOp(), v.GetChildren()));
}
} // namespace

GraphDef::GraphDef(NodeDef _root) : root(std::move(_root)) {}

NodeDef GraphDef::GetRoot() const
{
    return root;
}

// Top-down evaluation from the root; returns the root's value.
double GraphDef::propagate()
{
    std::unordered_set<NodeDef> done;
    _rpropagate(root, done);
    return root.GetValue();
}

// Bottom-up evaluation starting from `leaves`: a parent is evaluated once every
// one of its child edges has been reported ready. Returns the root's stored
// value, which is only refreshed if the root is reached from `leaves`.
double GraphDef::propagate(const std::vector<NodeDef>& leaves)
{
    std::queue<NodeDef> ready;
    // For each parent: how many of its child edges are ready so far.
    std::unordered_map<NodeDef, std::size_t> seen;
    for (const NodeDef& leaf : leaves)
        ready.push(leaf);
    while (!ready.empty())
    {
        NodeDef v = std::move(ready.front());
        ready.pop();
        for (NodeDef& parent : v.GetParents())
        {
            std::size_t& count = seen[parent];
            ++count;
            // BuildGraph registers the parent once per child edge, so the parent
            // is ready once the count equals its actual number of children. This
            // does not depend on a separate operator-arity table.
            if (count == parent.GetChildren().size())
            {
                parent.SetValue(_eval(parent.GetOp(), parent.GetChildren()));
                ready.push(parent);
            }
        }
    }
    return root.GetValue();
}

// Computes the value of the expression.
double GraphDef::eval()
{
    return propagate();
}
}