C++解析决策树PMML文档

PMML三问

PMML是什么?

预测模型标记语言(Predictive Model Markup Language,PMML)一种可以呈现预测分析模型的事实标准语言。

PMML价值是什么

实现模型的跨语言部署。举个例子,模型是用python训练但是希望部署到JAVA或者C++环境中,解决方案是将模型以PMML格式文件导出,然后利用其他语言进行解析、部署

PMML如何解析

JAVA语言中有JPMML工具、Python也有相关的工具;唯独C++没有找到,只能手写了。

接下来将会贴代码

解析过程

C++解析PMML文件中模型

本实例中用到的相关技术:tinyxml2(c++中解析XML的一个工具,PMML可以认为是XML中的一种特殊格式),具体使用可自行百度

                                            C++(解析语言)、(JAVA、JPMML工具,用来测试解析输出的结果是否正确)

                                            Python、pydotplus、PMMLPipeline(用Iris数据集训练一颗决策树,然后以PMML文件格式导出至本地)

 

Python导出PMML文件代码

from sklearn2pmml import PMMLPipeline
from sklearn.datasets import load_iris
from sklearn import tree
import graphviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
pipeline = PMMLPipeline([("classifier", clf)])
pipeline.fit(iris.data, iris.target)

clf.fit(iris.data, iris.target)
# 导出为PMML
from sklearn2pmml import sklearn2pmml
#sklearn2pmml(pipeline, "/Users/hzp/Desktop/DecisionTreeIris.pmml", with_repr = True)



import pydotplus

with open('/Users/hzp/Desktop/treeone.dot', 'w') as f:
      dot_data = tree.export_graphviz(clf, out_file=None)
      f.write(dot_data)
 
  #画图方法2-生成pdf文件
dot_data = tree.export_graphviz(clf, out_file=None,feature_names=clf.feature_importances_,
                                  filled=True, rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
 ##保存图像到pdf文件
graph.write_pdf("/Users/hzp/Desktop/treetwo.pdf")

JAVA解析PMML文件

参考我另外一篇博客

https://blog.csdn.net/zehui6202/article/details/105074143

伪代码思路

 
//递归遍历PMML中决策树预规则部分的节点
while(node){
      if (Nodename == "Node" and has_attribute){
         //满足条件认为是叶子节点
         label = score //叶子节点的标签属性
         Node -> Node->FirstChildren//指向第一个子节点
         if (input_feature 满足条件 or Node == "True"){
             return label
        }
        
        else{
                node -> NextNode // 指向下一个节点(当前是在子节点,是要指向父节点的下一个节点)
                  }
     }
    
     else if (Node == "Node" and has_not_attribute){
          node -> Node->FirstChildren
          if (input_feature 满足条件){
              node -> Firstchildren
        }
         
          elif if{ 
                node -> NextNode
               }
     }
}

C++解析决策树PMML文件

功能:输入一个特征,输出特征所属类别
输入:特征名、特征属性
输入格式:map
输出:特征被分类的label
 
附:也能解析xml的字符串,代码中有例子
#include <iostream>
#include "tinyxml2/simple.h"
#include <map>

#include<algorithm>
#include<string>
#include <sstream>


using namespace std;
using namespace tinyxml2;



//const char* xmlpathError="/Users/hzp/Downloads/TinyXML2-simple-master/errorExample.xml";



const char* xmlpathError="/Users/hzp/Desktop/DecisionTreeIris.pmml";
void ParserXMLFile();


float int2str(string num){
    float res;
    stringstream stream(num);
    stream>>res;
    return res;
}




void ParserXMLFile(int &depth,XMLNode *pNode,map<string,float >  input){
    // 根节点
    XMLNode *node=pNode;   //指针指向节点初始位置

   //单个节点解析

   // 如果节点是元素
   string label = "";

   while(node) {
       if (node->ToElement()) {
           XMLElement *element = node->ToElement();
           cout.width(depth);
           string NodeName = element->Name();    //指向节点名
           const XMLAttribute *attribute = element->FirstAttribute(); //获取节点第一个属性

           //取标签
           if (NodeName == "Node" and attribute) {
               //将所有属性聚合,输入到一个子节点中
               if (attribute) {
                   cout.width(depth);
                   while (attribute) {
                       string attributeName = attribute->Name();
                       string attributeValue = attribute->Value();
                       if (attributeName == "score") {
                           label = attributeValue;
                       }
                       attribute = attribute->Next();
                   }
               }

               node = node->FirstChild();
               XMLElement *element = node->ToElement();
               NodeName = element->Name();
               const XMLAttribute *attribute = element->FirstAttribute();
               //接下来是节点不等式的转换
               if (NodeName == "SimplePredicate") {
                   map<string, string> map_attribute;
                   if (attribute) {
                       cout.width(depth);
                       while (attribute) {
                           string name = attribute->Name();
                           string values = attribute->Value();
                           values = values.substr(values.find("(") + 1, values.find(")") - values.find("(") - 1);
                           map_attribute.insert(pair<string, string>(name, values));
                           attribute = attribute->Next();
                       }
                   }
                   //节点规则匹配
                   string field = map_attribute.at("field");
                   string rule = map_attribute.at("operator");
                   string rule_value = map_attribute.at("value");
                   if (rule == "lessOrEqual") {
                       float feature_value = input.at(field);
                       if (feature_value <= int2str(rule_value)) {
                           cout << label << "";

                       } else {
                           //跳到右节点
                           node = node->Parent();
                           node = node->NextSibling();
                       }
                   }
               }else if (NodeName == "True") {
                   cout << label <<endl;
               }
           }
               //将所有属性聚合,输入到一个子节点中
           else if (NodeName == "Node" and !attribute) {
               node = node->FirstChild();
               XMLElement *element = node->ToElement();
               NodeName = element->Name();
               const XMLAttribute *attribute = element->FirstAttribute();

               if (NodeName == "SimplePredicate") {
                   map<string, string> map_attribute;
                   if (attribute) {
                       cout.width(depth);
                       while (attribute) {
                           string name = attribute->Name();
                           string values = attribute->Value();
                           values = values.substr(values.find("(") + 1, values.find(")") - values.find("(") - 1);
                           map_attribute.insert(pair<string, string>(name, values));
                           attribute = attribute->Next();
                       }
                   }
                   //节点规则匹配
                   string field = map_attribute.at("field");
                   string rule = map_attribute.at("operator");
                   string rule_value = map_attribute.at("value");
                   if (rule == "lessOrEqual") {
                       float feature_value = input.at(field);
                       if (feature_value <= int2str(rule_value)) {
                           node->FirstChild();
                       } else {
                           node = node->NextSibling();
                       }
                   }
               }
           }
       }

       // 当前节点node的第一个子节点

       if (node->FirstChild()) {
           depth += 10;
           if(label ==""){
           ParserXMLFile(depth, node->FirstChild(), input);
           }

       }
       //遍历完之后指向下一个节点
       node = node->NextSibling();
   }

       if(depth>0){
        depth-=10;
    }
}

int main(){

    XMLDocument document;
    XMLError xmlError;
//    XMLError errXml = document.Parse(pXml);
    xmlError = document.LoadFile("/Users/hzp/Desktop/DecisionTreeIris.pmml");
    if(document.FirstChild()->ToDeclaration()){
//        cout<<document.FirstChild()->ToDeclaration()->Value()<<endl;
    }
    if(document.FirstChild()->NextSibling()->ToComment()){
//        cout<<document.FirstChild()->NextSibling()->ToComment()->Value()<<endl;
    }
//    cout<<"Root Element: "<< document.RootElement()->Name()<<endl;
    int depth=0;

//    parseSimple();

//5.1	3.5	1.4	0.2  --label 0
// 6.3	2.5	5	1.9 --- label 2
//5.9	3.2	4.8	1.8  -- label 1
    map<string, float > input;
    input.insert(pair<string, float >("x1",5.9));
    input.insert(pair<string, float >("x2",3.2));
    input.insert(pair<string, float >("x3",4.8));
    input.insert(pair<string, float >("x4",1.8));
    ParserXMLFile(depth,document.RootElement()->LastChild()->LastChild()->FirstChild(),input);
    return 0;
}
//void parseSimple(){
//    try {
//        Simplexml* simplexml;
//        simplexml=new Simplexml(xmlpathError);
//
//        simplexml->next("DataDictionary");//将头指针向子节点移动
//
//        cout<<simplexml->child("DataField",0)->child("Value",0)->attr("value")<<endl;
//
//    }catch (string e){
//        cout<<e<<endl;
//    }
//}

转载请注明出处

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值