创建xgboost的模型,训练后保存为pmml文件,这个都没有什么问题的,网上也有很多资源,其实我这个也是抄来的(小声),
import pandas
from xgboost.sklearn import XGBClassifier,XGBRegressor
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn.preprocessing import LabelEncoder
iris_df = pandas.read_csv("iris.csv")
iris_df['species'] = LabelEncoder().fit_transform(iris_df['species'].values)
#iris_df.columns
#['sepal_length', 'sepal_width', 'petal_length', 'petal_width','species']
clf = XGBClassifier(
silent=0 ,#设置成1则没有运行信息输出,最好是设置为0.是否在运行升级时打印消息。
#nthread=4,# cpu 线程数 默认最大
learning_rate= 0.3, # 如同学习率
min_child_weight=1,
# 这个参数默认是 1,是每个叶子里面 h 的和至少是多少,对正负样本不均衡时的 0-1 分类而言
#,假设 h 在 0.01 附近,min_child_weight 为 1 意味着叶子节点中最少需要包含 100 个样本。
#这个参数非常影响结果,控制叶子节点中二阶导的和的最小值,该参数值越小,越容易 overfitting。
max_depth=6, # 构建树的深度,越大越容易过拟合
gamma=0, # 树的叶子节点上作进一步分区所需的最小损失减少,越大越保守,一般0.1、0.2这样子。
subsample=1, # 随机采样训练样本 训练实例的子采样比
max_delta_step=0,#最大增量步长,我们允许每个树的权重估计。
colsample_bytree=1, # 生成树时进行的列采样
reg_lambda=1, # 控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。
objective= 'multi:softmax', #多分类的问题 指定学习任务和相应的学习目标
n_estimators=100, #树的个数
seed=1000
)
pipeline = PMMLPipeline([("classifier", clf)])
pipeline.fit(iris_df[iris_df.columns.difference(["species"])],iris_df["species"])
sklearn2pmml(pipeline,"xgboost.pmml",with_repr = True)
紧接着我们再创建java项目来调用pmml,要想java调用pmml,需要引入相应的jar包
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.13</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.13</version>
</dependency>
然后就可以了
package com.test.pmml;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
/**
* Hello world!
*
*/
class PMMLDemo {
private Evaluator loadPmml() {
PMML pmml = new PMML();
InputStream inputStream = null;
try {
inputStream = new FileInputStream("xgboost.pmml");
} catch (IOException e) {
e.printStackTrace();
}
if (inputStream == null) {
return null;
}
InputStream is = inputStream;
try {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
} catch (SAXException e1) {
e1.printStackTrace();
} catch (JAXBException e1) {
e1.printStackTrace();
} finally {
//关闭输入流
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
pmml = null;
return evaluator;
}
private int predict(Evaluator evaluator,Map<String, Double> featuremap) {
List<InputField> inputFields = evaluator.getInputFields();
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = featuremap.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields();
TargetField targetField = targetFields.get(0);
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
int primitiveValue = -1;
if (targetFieldValue instanceof Computable) {
Computable computable = (Computable) targetFieldValue;
System.out.println(computable.getResult());
primitiveValue = (Integer)computable.getResult();
}
return primitiveValue;
}
public static void main(String args[]){
PMMLDemo demo = new PMMLDemo();
Evaluator model = demo.loadPmml();
Map<String, Double> data = new HashMap<String, Double>();
//这里的key一定要对应python中的列名,一开始我在网上找的例子是随便起的名字,不管输入什么数据返回结果都是0
data.put("sepal_length", 5.1);
data.put("sepal_width", 3.5);
data.put("petal_length", 1.4);
data.put("petal_width", 0.3);
demo.predict(model,data);
}
}