PMML三问
PMML是什么?
预测模型标记语言(Predictive Model Markup Language,PMML)是一种可以呈现预测分析模型的事实标准语言。
PMML价值是什么
实现模型的跨语言部署。举个例子,模型是用python训练但是希望部署到JAVA或者C++环境中,解决方案是将模型以PMML格式文件导出,然后利用其他语言进行解析、部署
PMML如何解析
JAVA语言中有JPMML工具、Python也有相关的工具;唯独C++没有找到,只能手写了。
接下来将会贴代码
解析过程
C++解析PMML文件中模型
本实例中用到的相关技术:tinyxml2(c++中解析XML的一个工具,PMML可以认为是XML中的一种特殊格式),具体使用可自行百度
C++(解析语言)、(JAVA、JPMML工具,用来测试解析输出的结果是否正确)
Python、pydotplus、PMMLPipeline(用Iris数据集训练一颗决策树,然后以PMML文件格式导出至本地)
Python导出PMML文件代码
from sklearn2pmml import PMMLPipeline
from sklearn.datasets import load_iris
from sklearn import tree
import graphviz
iris = load_iris()
clf = tree.DecisionTreeClassifier()
pipeline = PMMLPipeline([("classifier", clf)])
pipeline.fit(iris.data, iris.target)
clf.fit(iris.data, iris.target)
# 导出为PMML
from sklearn2pmml import sklearn2pmml
#sklearn2pmml(pipeline, "/Users/hzp/Desktop/DecisionTreeIris.pmml", with_repr = True)
import pydotplus
with open('/Users/hzp/Desktop/treeone.dot', 'w') as f:
dot_data = tree.export_graphviz(clf, out_file=None)
f.write(dot_data)
#画图方法2-生成pdf文件
dot_data = tree.export_graphviz(clf, out_file=None,feature_names=clf.feature_importances_,
filled=True, rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
##保存图像到pdf文件
graph.write_pdf("/Users/hzp/Desktop/treetwo.pdf")
JAVA解析PMML文件
参考我另外一篇博客
https://blog.csdn.net/zehui6202/article/details/105074143
伪代码思路
//递归遍历PMML中决策树预规则部分的节点
while(node){
if (Nodename == "Node" and has_attribute){
//满足条件认为是叶子节点
label = score //叶子节点的标签属性
Node -> Node->FirstChildren//指向第一个子节点
if (input_feature 满足条件 or Node == "True"){
return label
}
else{
node -> NextNode // 指向下一个节点(当前是在子节点,是要指向父节点的下一个节点)
}
}
else if (Node == "Node" and has_not_attribute){
node -> Node->FirstChildren
if (input_feature 满足条件){
node -> Firstchildren
}
elif if{
node -> NextNode
}
}
}
C++解析决策树PMML文件
功能:输入一个特征,输出特征所属类别
输入:特征名、特征属性
输入格式:map
输出:特征被分类的label
附:也能解析xml的字符串,代码中有例子
#include <iostream>
#include "tinyxml2/simple.h"
#include <map>
#include<algorithm>
#include<string>
#include <sstream>
using namespace std;
using namespace tinyxml2;
//const char* xmlpathError="/Users/hzp/Downloads/TinyXML2-simple-master/errorExample.xml";
const char* xmlpathError="/Users/hzp/Desktop/DecisionTreeIris.pmml";
void ParserXMLFile();
float int2str(string num){
float res;
stringstream stream(num);
stream>>res;
return res;
}
void ParserXMLFile(int &depth,XMLNode *pNode,map<string,float > input){
// 根节点
XMLNode *node=pNode; //指针指向节点初始位置
//单个节点解析
// 如果节点是元素
string label = "";
while(node) {
if (node->ToElement()) {
XMLElement *element = node->ToElement();
cout.width(depth);
string NodeName = element->Name(); //指向节点名
const XMLAttribute *attribute = element->FirstAttribute(); //获取节点第一个属性
//取标签
if (NodeName == "Node" and attribute) {
//将所有属性聚合,输入到一个子节点中
if (attribute) {
cout.width(depth);
while (attribute) {
string attributeName = attribute->Name();
string attributeValue = attribute->Value();
if (attributeName == "score") {
label = attributeValue;
}
attribute = attribute->Next();
}
}
node = node->FirstChild();
XMLElement *element = node->ToElement();
NodeName = element->Name();
const XMLAttribute *attribute = element->FirstAttribute();
//接下来是节点不等式的转换
if (NodeName == "SimplePredicate") {
map<string, string> map_attribute;
if (attribute) {
cout.width(depth);
while (attribute) {
string name = attribute->Name();
string values = attribute->Value();
values = values.substr(values.find("(") + 1, values.find(")") - values.find("(") - 1);
map_attribute.insert(pair<string, string>(name, values));
attribute = attribute->Next();
}
}
//节点规则匹配
string field = map_attribute.at("field");
string rule = map_attribute.at("operator");
string rule_value = map_attribute.at("value");
if (rule == "lessOrEqual") {
float feature_value = input.at(field);
if (feature_value <= int2str(rule_value)) {
cout << label << "";
} else {
//跳到右节点
node = node->Parent();
node = node->NextSibling();
}
}
}else if (NodeName == "True") {
cout << label <<endl;
}
}
//将所有属性聚合,输入到一个子节点中
else if (NodeName == "Node" and !attribute) {
node = node->FirstChild();
XMLElement *element = node->ToElement();
NodeName = element->Name();
const XMLAttribute *attribute = element->FirstAttribute();
if (NodeName == "SimplePredicate") {
map<string, string> map_attribute;
if (attribute) {
cout.width(depth);
while (attribute) {
string name = attribute->Name();
string values = attribute->Value();
values = values.substr(values.find("(") + 1, values.find(")") - values.find("(") - 1);
map_attribute.insert(pair<string, string>(name, values));
attribute = attribute->Next();
}
}
//节点规则匹配
string field = map_attribute.at("field");
string rule = map_attribute.at("operator");
string rule_value = map_attribute.at("value");
if (rule == "lessOrEqual") {
float feature_value = input.at(field);
if (feature_value <= int2str(rule_value)) {
node->FirstChild();
} else {
node = node->NextSibling();
}
}
}
}
}
// 当前节点node的第一个子节点
if (node->FirstChild()) {
depth += 10;
if(label ==""){
ParserXMLFile(depth, node->FirstChild(), input);
}
}
//遍历完之后指向下一个节点
node = node->NextSibling();
}
if(depth>0){
depth-=10;
}
}
int main(){
XMLDocument document;
XMLError xmlError;
// XMLError errXml = document.Parse(pXml);
xmlError = document.LoadFile("/Users/hzp/Desktop/DecisionTreeIris.pmml");
if(document.FirstChild()->ToDeclaration()){
// cout<<document.FirstChild()->ToDeclaration()->Value()<<endl;
}
if(document.FirstChild()->NextSibling()->ToComment()){
// cout<<document.FirstChild()->NextSibling()->ToComment()->Value()<<endl;
}
// cout<<"Root Element: "<< document.RootElement()->Name()<<endl;
int depth=0;
// parseSimple();
//5.1 3.5 1.4 0.2 --label 0
// 6.3 2.5 5 1.9 --- label 2
//5.9 3.2 4.8 1.8 -- label 1
map<string, float > input;
input.insert(pair<string, float >("x1",5.9));
input.insert(pair<string, float >("x2",3.2));
input.insert(pair<string, float >("x3",4.8));
input.insert(pair<string, float >("x4",1.8));
ParserXMLFile(depth,document.RootElement()->LastChild()->LastChild()->FirstChild(),input);
return 0;
}
//void parseSimple(){
// try {
// Simplexml* simplexml;
// simplexml=new Simplexml(xmlpathError);
//
// simplexml->next("DataDictionary");//将头指针向子节点移动
//
// cout<<simplexml->child("DataField",0)->child("Value",0)->attr("value")<<endl;
//
// }catch (string e){
// cout<<e<<endl;
// }
//}
转载请注明出处