前言
- 最近公司有个需求,需要对用户进行数据画像分析。
- 公司大数据组通过对线上用户数据进行分析后,通过机器学习用python做了一个训练模型pkl文件包。
- 要求我部门对用户数据进行分析计算。而我部门的项目都是使用Java进行开发的,所以就需要Java调用pkl训练模型包。
- 经过调研python的pkl训练模型包不能直接被Java调用,跨平台调用需要使用pmml格式文件,所以就让大数据部门依照已经生成的训练模型pkl文件,在次封装成一个pmml文件。
pmml格式
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
<Header>
<Application name="JPMML-SkLearn" version="1.6.27"/>
<Timestamp>2021-08-30T06:48:45Z</Timestamp>
</Header>
<DataDictionary>
<DataField name="y" optype="categorical" dataType="integer">
<Value value="0"/>
<Value value="1"/>
</DataField>
<DataField name="x1" optype="continuous" dataType="double"/>
<DataField name="x2" optype="continuous" dataType="double"/>
<DataField name="x3" optype="continuous" dataType="double"/>
</DataDictionary>
<RegressionModel functionName="classification" algorithmName="sklearn.linear_model._logistic.LogisticRegression" normalizationMethod="logit">
<MiningSchema>
<MiningField name="y" usageType="target"/>
<MiningField name="x1"/>
<MiningField name="x2"/>
<MiningField name="x3"/>
</MiningSchema>
<RegressionTable intercept="0.5920457931585216" targetCategory="1">
<NumericPredictor name="x1" coefficient="0.7586778342148665"/>
<NumericPredictor name="x2" coefficient="0.6562980822443883"/>
<NumericPredictor name="x3" coefficient="0.9917332587791079"/>
</RegressionTable>
<RegressionTable intercept="0.0" targetCategory="0"/>
</RegressionModel>
</PMML>
Java调用pmml文件
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.1</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.1</version>
</dependency>
- Java调用方法
- 当有test.pmml文件后,可以把文件放在springboot项目的resources目录下,使用ClassPathResource类获取到文件流
/**
* @Author: ZRH
* @Date: 2021/8/30 9:17
*/
@Slf4j
public final class ClassificationModelOld {
private static Evaluator modelEvaluator;
static {
PMML pmml;
try {
Resource resource = new ClassPathResource("test.pmml");
InputStream is = resource.getInputStream();
pmml = PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info("InputStream close error!");
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
modelEvaluator.verify();
log.info("加载模型成功!");
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 私有化构造函数,防止外部创建实例
*/
private ClassificationModelOld () {
}
/**
* 获取模型需要的特征名称
*
* @return
*/
public static List<String> getFeatureNames () {
List<String> featureNames = new ArrayList<>();
List<InputField> inputFields = modelEvaluator.getInputFields();
for (InputField inputField : inputFields) {
featureNames.add(inputField.getName().toString());
}
return featureNames;
}
/**
* 获取目标字段名称
*
* @return
*/
public static String getTargetName () {
return modelEvaluator.getTargetFields().get(0).getName().toString();
}
/**
* 使用模型生成概率分布
*
* @param arguments
* @return
*/
private static ProbabilityDistribution getProbabilityDistribution (Map<FieldName, ?> arguments) {
Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);
FieldName fieldName = FieldName.create(getTargetName());
return (ProbabilityDistribution) evaluateResult.get(fieldName);
}
/**
* 预测不同分类的概率
*
* @param arguments
* @return
*/
public static ValueMap<String, Number> predictProba (Map<FieldName, Number> arguments) {
ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
return probabilityDistribution.getValues();
}
/**
* 预测结果分类
*
* @param arguments
* @return
*/
public static Object predict (Map<FieldName, ?> arguments) {
ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
return probabilityDistribution.getPrediction();
}
private static Integer setScore (float probability) {
int score = 0;
try {
// TODO 根据比例写算法计算出分值
score = 520;
} catch (Exception e) {
}
return score;
}
public static void main (String[] args) {
// 参数进过转义后:{{"value":"x1"}:-0.216918810277242,{"value":"x2"}:0.0583184157700168,{"value":"x3"}:-0.653728631926331}
final ArrayList<Double> doubles = Lists.newArrayList(-0.216918810277242, 0.0583184157700168, -0.653728631926331);
Map<FieldName, Number> waitPreSample = new HashMap<>(8);
waitPreSample.put(FieldName.create("x1"), doubles.get(0));
waitPreSample.put(FieldName.create("x2"), doubles.get(1));
waitPreSample.put(FieldName.create("x3"), doubles.get(2));
final ValueMap<String, Number> values = ClassificationModelOld.predictProba(waitPreSample);
System.out.println("机器算法计算分值结果:" + setScore(values.get("1").floatValue()));
}
}
---------------------
执行结果:
加载模型成功!
机器算法计算分值结果:520
版本问题
- 上面示例是使用的老版本的包,并且打的pmml文件也是4.3版本的
- 所以如果使用的是4.4版本的pmml文件
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.5.11</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.5.11</version>
</dependency>
static {
PMML pmml;
try {
Resource resource = new ClassPathResource("test.pmml");
InputStream is = resource.getInputStream();
pmml = PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info("InputStream close error!");
}
ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
modelEvaluator = modelEvaluatorBuilder.build();
modelEvaluator.verify();
log.info("加载模型成功!");
} catch (Exception e) {
e.printStackTrace();
}
}
- 这样4.4版本的pmml训练模型文件也是可以执行获取结果
最后