C++ 将数学表达式构建成图

C++ 将数学表达式构建成图

flyfish

调试环境VC2017

构建成图之后,再计算表达式的值

原文是
How To Write Your Own Tensorflow in C++
如何使用C++编写自己的Tensorflow

本文是将数学表达式构建成图,再计算表达式的值,部分代码摘抄,为了容易理解,更改其代码

例如将以下表达式构建成图

t1=1+2
t2=t1*3
t3=t2+exp(4)

用下面BuildGraph函数会构建这样的图
这里写图片描述

主函数调用代码

#include "stdafx.h"
#include "GraphDef.h"
int main()
{
    ai::NodeDef t1 = 12;
    ai::NodeDef t2 = t1 + ai::poly(2, 3);
    ai::NodeDef t3(10);
    ai::NodeDef t4 = t3 * 11+ t1* t2;

    ai::GraphDef g(t4);
    double a=g.eval();  
    return 0;
}

NodeDef.h 头文件

#pragma once

#include <iostream>
#include <vector>
#include <memory>
#include <utility>

namespace ai
{

class NodeDef;

enum class OpDef {
    plus,// operator+
    minus,// operator-
    multiply,// operator*
    divide,// operator/
    exponent,// exp() // e^x
    polynomial,// poly() // x^n 指数幂 在a^n中,a叫做底数,n叫做指数。a^n读作“a的n次方”或“a的n次幂“。
    none // no operators. leaf.
};

int numOpArgs(OpDef op);

}

namespace std
{

template <> struct hash<ai::NodeDef> {
    size_t operator()(const ai::NodeDef&) const;
};
}

namespace ai
{


class NodeDef
{

struct impl;

public:

    NodeDef(std::shared_ptr<impl>);

    NodeDef(double);
    NodeDef(OpDef, const std::vector<NodeDef>&);
    ~NodeDef();


    NodeDef(NodeDef&&) noexcept;
    NodeDef& operator=(NodeDef&&) noexcept;

    //浅拷贝
    NodeDef(const NodeDef&);
    NodeDef& operator=(const NodeDef&);

    //深拷贝
    NodeDef Clone();


    double GetValue() const;
    void SetValue(double);
    OpDef GetOp() const;
    void SetOp(OpDef);


    std::vector<NodeDef>& GetChildren() const;
    std::vector<NodeDef> GetParents() const;


    bool operator==(const NodeDef& rhs) const;
    friend struct std::hash<NodeDef>;

    template <typename... V>
    friend const NodeDef BuildGraph(OpDef, V&...);
private: 
    // PImpl idiom :
    std::shared_ptr<impl> pimpl;
};

struct NodeDef::impl
{
public:

    impl(double);
    impl(OpDef, const std::vector<NodeDef>&);
    double val;
    OpDef op; 
    std::vector<NodeDef> children;

    std::vector<std::weak_ptr<impl>> parents;
};


template <typename... V>
const NodeDef BuildGraph(OpDef op, V&... args){
    std::vector<std::shared_ptr<NodeDef::impl> > vimpl = { args.pimpl... };
    std::vector<NodeDef> v;
    for(const std::shared_ptr<NodeDef::impl>& _impl : vimpl){
        v.emplace_back(_impl); 
    }
    NodeDef res(op, v);
    for(const std::shared_ptr<NodeDef::impl>& _impl : vimpl){
        _impl->parents.push_back(res.pimpl);
    }
    return res;
}


inline const NodeDef operator+(NodeDef lhs, NodeDef rhs){
    return BuildGraph(OpDef::plus, lhs, rhs);
}

inline const NodeDef operator-(NodeDef lhs, NodeDef rhs){
    return BuildGraph(OpDef::minus, lhs, rhs);
}

inline const NodeDef operator*(NodeDef lhs, NodeDef rhs){
    return BuildGraph(OpDef::multiply, lhs, rhs);
}

inline const NodeDef operator/(NodeDef lhs, NodeDef rhs){
    return BuildGraph(OpDef::divide, lhs, rhs);
}

inline const NodeDef exp(NodeDef v){
    return BuildGraph(OpDef::exponent, v);
}

inline const NodeDef poly(NodeDef v, NodeDef power){
    NodeDef p(power);
    return BuildGraph(OpDef::polynomial, v, p);
}

}

NodeDef.cpp 实现文件

#include "stdafx.h"
#include "NodeDef.h"
#include <map>

namespace ai
{


    int numOpArgs(OpDef op)
    {
        static const std::map<OpDef, int> op_args = {
            { OpDef::plus, 2 },
            { OpDef::minus, 2 },
            { OpDef::multiply, 2 },
            { OpDef::divide, 2 },
            { OpDef::exponent, 1 },
            { OpDef::polynomial, 1 },
            { OpDef::none, 0 },
        };
        return op_args.find(op)->second;
    };


    NodeDef::NodeDef(NodeDef&&) noexcept = default;
    NodeDef& NodeDef::operator=(NodeDef&&) noexcept = default;
    NodeDef::~NodeDef() = default;
    NodeDef::NodeDef(const NodeDef&) = default;
    NodeDef& NodeDef::operator=(const NodeDef&) = default;
    NodeDef NodeDef::Clone()
    {
        return NodeDef(std::make_shared<impl>(*pimpl));
    }

    NodeDef::NodeDef(std::shared_ptr<impl> _pimpl) : pimpl(_pimpl) {};

    NodeDef::NodeDef(double _val)
        : pimpl(new impl(_val)) {}

    NodeDef::NodeDef(OpDef _op, const std::vector<NodeDef>& _children)
        : pimpl(new impl(_op, _children)) {}

    /* Getters and Setters */
    double NodeDef::GetValue() const { return pimpl->val; }

    void NodeDef::SetValue(double _val) { pimpl->val = _val; }

    OpDef NodeDef::GetOp() const { return pimpl->op; }

    void NodeDef::SetOp(OpDef _op) { pimpl->op = _op; }

    std::vector<NodeDef>& NodeDef::GetChildren() const { return pimpl->children; }

    std::vector<NodeDef> NodeDef::GetParents() const {
        std::vector<NodeDef> _parents;
        for (std::weak_ptr<impl> parent : pimpl->parents) {
            _parents.emplace_back(parent.lock());
        }
        return _parents;
    }



    bool NodeDef::operator==(const NodeDef& rhs) const
    {
        return pimpl.get() == rhs.pimpl.get();
    }

    NodeDef::impl::impl(double _val) :
        val(_val),
        op(OpDef::none) {}

    NodeDef::impl::impl(OpDef _op, const std::vector<NodeDef>& _children)
        : op(_op) {
        for (const NodeDef& v : _children) {
            children.emplace_back(v.pimpl);
        }
    }

}

namespace std
{
    size_t hash<ai::NodeDef>::operator()(const ai::NodeDef& v) const
    {
        return std::hash<std::shared_ptr<ai::NodeDef::impl> >{}(v.pimpl);
    }
}

GraphDef.h 头文件

#pragma once

#include "NodeDef.h"
#include <queue>
#include <unordered_map>
#include <unordered_Set>

namespace ai
{
class GraphDef 
{
public:
    GraphDef(NodeDef);
    NodeDef GetRoot() const;
    double propagate();
    double propagate(const std::vector<NodeDef>& leaves);

private:
    NodeDef root;
public:
    double eval();//计算值
};

}

GraphDef.cpp 实现文件

#include "stdafx.h"
#include "GraphDef.h"
#include <cmath>
#include <exception>

namespace ai
{
double _eval(OpDef op, const std::vector<NodeDef>& operands)
{
    switch(op){
        case OpDef::plus:
            return operands[0].GetValue() + operands[1].GetValue();
        case OpDef::minus:
            return operands[0].GetValue() - operands[1].GetValue();
        case OpDef::multiply:
            return operands[0].GetValue() * operands[1].GetValue();
        case OpDef::divide:
            return operands[0].GetValue() / operands[1].GetValue();
        case OpDef::exponent:
            return std::exp(operands[0].GetValue());
        case OpDef::polynomial:
            return std::pow(operands[0].GetValue(), operands[1].GetValue());
        case OpDef::none:
            throw std::invalid_argument("_eval invalid argument");
    }; 

    return 0.0;
}


GraphDef::GraphDef(NodeDef _root) : root(_root){}

NodeDef GraphDef::GetRoot()const
{
    return root;
}

void _rpropagate(NodeDef& v)
{
    if(v.GetChildren().empty())
        return;
    std::vector<NodeDef> children = v.GetChildren(); 
    for(NodeDef& _v : children){
        _rpropagate(_v);
    }
    v.SetValue(_eval(v.GetOp(), v.GetChildren()));
}

double GraphDef::propagate()
{
    _rpropagate(root);
    return root.GetValue();
}

double GraphDef::propagate(const std::vector<NodeDef>& leaves)
{
    std::queue<NodeDef> q;
    std::unordered_map<NodeDef, int> explored; 
    for(const NodeDef& v : leaves)
        q.push(v);

    while(!q.empty())
    {
        NodeDef v = q.front();
        q.pop();
        std::vector<NodeDef> parents = v.GetParents();
        for(NodeDef& parent : parents)
        {
            explored[parent]++; 
            if(numOpArgs(parent.GetOp()) == explored[parent])
            {
                parent.SetValue(_eval(parent.GetOp(), parent.GetChildren()));
                q.push(parent);
            }
        } 
    } 
    return root.GetValue();
}

double GraphDef::eval()
{
    return propagate();
}

}
使用boost::sprirt编写的表达式解析器,代码很容易扩展,功能很强大,适合做功能强大的客户化定义程序。 -----------表达式解析--------- 已定义的函数有:PI,SIN,COS,TAN,,ABS,EXP,LOGN,POW,SQRT,FORMAT,LENGTH,SUBSTR 强制类型转换请使用:(数据类型)数据 例如:(int) "100" 的值为int型100 已定义的变量有:var1=123,var2=this is a string 请输入您的表达式>>1+2 计算结果:3 XML格式显示计算过程: <?xml version="1.0" encoding="ISO-8859-1"?> <!DOCTYPE parsetree SYSTEM "parsetree.dtd"> <!-- 1+2 --> <parsetree version="1.0"> <parsenode rule="add_expr"> <value>+</value> <parsenode rule="integer_const"> <value>1</value> </parsenode> <parsenode rule="integer_const"> <value>2</value> </parsenode> </parsenode> </parsetree> 计算结果数据类型:integer 计算结果:3 请输入您的表达式>>"ab"+"cd" 计算结果:"abcd" XML格式显示计算过程: <?xml version="1.0" encoding="ISO-8859-1"?> <!DOCTYPE parsetree SYSTEM "parsetree.dtd"> <!-- "ab"+"cd" --> <parsetree version="1.0"> <parsenode rule="add_expr"> <value>+</value> <parsenode rule="string_const"> <value>"ab"</value> </parsenode> <parsenode rule="string_const"> <value>"cd"</value> </parsenode> </parsenode> </parsetree> 计算结果数据类型:string 计算结果:abcd 请输入您的表达式>>format("1+2=%d",1+2) 计算结果:format("1+2=%d",3) XML格式显示计算过程: <?xml version="1.0" encoding="ISO-8859-1"?> <!DOCTYPE parsetree SYSTEM "parsetree.dtd"> <!-- format("1+2=%d",1+2) --> <parsetree version="1.0"> <parsenode rule="function_identifier"> <value>format</value> <parsenode rule="exprlist"> <parsenode rule="string_const"> <value>"1+2=%d"</value> </parsenode> <parsenode rule="add_expr"> <value>+</value> <parsenode rule="integer_const"> <value>1</value> </parsenode> <parsenode rule="integer_const"> <value>2</value> </parsenode> </parsenode> </parsenode> </parsenode> </parsetree> 计算结果数据类型:string 计算结果:1+2=3 请输入您的表达式>>PI() 计算结果:PI() XML格式显示计算过程: <?xml version="1.0" encoding="ISO-8859-1"?> <!DOCTYPE parsetree SYSTEM "parsetree.dtd"> <!-- PI() --> <parsetree version="1.0"> <parsenode rule="function_identifier"> <value>PI</value> </parsenode> </parsetree> 计算结果数据类型:double 计算结果:3.1415926535897931 请输入您的表达式>>cos(1.5) 计算结果:cos(1.5) XML格式显示计算过程: <?xml version="1.0" encoding="ISO-8859-1"?> <!DOCTYPE parsetree SYSTEM "parsetree.dtd"> <!-- cos(1.5) --> <parsetree version="1.0"> <parsenode rule="function_identifier"> <value>cos</value> <parsenode rule="double_const"> <value>1.5</value> </parsenode> </parsenode> </parsetree> 计算结果数据类型:double 计算结果:0.070737201667702906 请输入您的表达式>>q
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

西笑生

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值