Overview
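Picking up from the previous post: once the Iris example code was running, it was time to try my own model. After a fair amount of painful adjustment, I finally got my first unsupervised learning model working. This example uses a PMML model generated with sklearn2pmml. The model is built as a PMMLPipeline (from sklearn2pmml.pipeline) containing a PCA step followed by a KMeans step. It is called from Java through the JPMML library to predict which group (cluster) a new member belongs to.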
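To make the pipeline concrete, here is a minimal sketch of what a PCA step followed by a KMeans step computes for one new member. PCA (without whitening) centers the input with the training means and projects it onto the principal components; KMeans then assigns the projected point to the nearest centroid by squared Euclidean distance. This is plain Java, not JPMML: the means, components, and centroids are hypothetical toy numbers, not the parameters stored in edu.pmml (JPMML reads the real ones from the PMML file). If the PCA was fitted with whiten=True, an extra per-component scaling step would be needed, which this sketch omits.

```java
import java.util.Arrays;

// Hypothetical illustration of a PCA -> KMeans prediction.
// All numeric parameters are toy values, not those stored in edu.pmml.
class PcaKMeansSketch {

    // PCA transform without whitening: z[k] = sum_j components[k][j] * (x[j] - mean[j])
    static double[] project(double[] x, double[] mean, double[][] components) {
        double[] z = new double[components.length];
        for (int k = 0; k < components.length; k++) {
            double sum = 0.0;
            for (int j = 0; j < x.length; j++) {
                sum += components[k][j] * (x[j] - mean[j]);
            }
            z[k] = sum;
        }
        return z;
    }

    // KMeans prediction: index of the centroid with the smallest squared Euclidean distance
    static int nearestCentroid(double[] z, double[][] centroids) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double d = 0.0;
            for (int k = 0; k < z.length; k++) {
                double diff = z[k] - centroids[c][k];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] mean = {1.0, 2.0, 3.0};              // hypothetical training means
        double[][] components = {{1.0, 0.0, 0.0},     // hypothetical principal components
                                 {0.0, 1.0, 0.0}};
        double[][] centroids = {{-1.0, -1.0},         // hypothetical cluster centers in PCA space
                                { 2.0,  2.0}};
        double[] x = {3.0, 4.0, 0.0};                 // the new member
        double[] z = project(x, mean, components);
        System.out.println(Arrays.toString(z) + " -> cluster " + nearestCentroid(z, centroids));
    }
}
```

With these toy values the centered input is (2, 2, -3), its projection is (2, 2), and the point coincides with the second centroid, so `main` reports cluster 1.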
Code
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;

import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class TestPmml {

    public static void main(String[] args) throws Exception {
        // Path to the PMML file exported by sklearn2pmml
        String fp = "edu.pmml";

        // Initialize the model
        Evaluator model = new LoadingModelEvaluatorBuilder()
                .load(new File(fp))
                .build();

        // Build the input: one entry per input field, keyed by field name
        Map<FieldName, Object> arguments = new HashMap<>();
        arguments.put(FieldName.create("Aggression"), 35.6058);
        arguments.put(FieldName.create("Stress"), 24.193);
        arguments.put(FieldName.create("Tension"), 47.7854);
        arguments.put(FieldName.create("Suspect"), 36.18);
        arguments.put(FieldName.create("Balance"), 77.5673);
        arguments.put(FieldName.create("Charm"), 78.1811);
        arguments.put(FieldName.create("Energy"), 25.8678);
        arguments.put(FieldName.create("Self-Regulation"), 77.715);
        arguments.put(FieldName.create("Inhibition"), 25.4405);
        arguments.put(FieldName.create("Neuroticism"), 34.7914);
        arguments.put(FieldName.create("Extraversion"), 0.619428);
        arguments.put(FieldName.create("Stability"), 0.509011);
        arguments.put(FieldName.create("Depression"), 20.4566);
        arguments.put(FieldName.create("Happiness"), 37.0281);
        arguments.put(FieldName.create("FH_M"), 1.98672);
        arguments.put(FieldName.create("FH_S"), 0.630348);

        // Evaluate the model
        Map<FieldName, ?> results = model.evaluate(arguments);

        // Read the results, one target field at a time
        List<TargetField> targetFields = model.getTargetFields();
        for (TargetField targetField : targetFields) {
            Object targetValue = results.get(targetField.getName());
            System.out.println("target=" + EvaluatorUtil.decode(targetValue));

            // Optional computation report (regression and classification results)
            if (targetValue instanceof HasReport) {
                Report report = ((HasReport) targetValue).getReport();
                if (report != null) {
                    System.out.println("report=" + ReportUtil.format(report));
                }
            }

            // Per-category probabilities (probabilistic classification results)
            if (targetValue instanceof HasProbability) {
                HasProbability hasProbability = (HasProbability) targetValue;
                Set<?> targetCategories = hasProbability.getCategories();
                for (Object targetCategory : targetCategories) {
                    Double probability = hasProbability.getProbability(targetCategory);
                    System.out.println("probability(" + targetCategory + ")=" + probability);
                    Report probabilityReport = hasProbability.getProbabilityReport(targetCategory);
                    if (probabilityReport != null) {
                        System.out.println("probabilityReport(" + targetCategory + ")=" + ReportUtil.format(probabilityReport));
                    }
                }
            }
        }
    }
}
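The key point of this example is the standardized input and output. JPMML defines a standard input format, Map<FieldName, Object> arguments, which maps each input field name to its value. It also defines a standard way to obtain results: evaluate returns a map from field names to computed values, and getTargetFields lists the target fields whose values you look up in that map.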
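In the code above the sixteen keys are typed out by hand. Below is a minimal sketch of a more defensive way to build the same input map from a raw record (for example one parsed CSV row), assuming the list of expected field names is known. It keeps only the declared inputs, converts them to Double, and fails fast when one is missing. The helper name `buildArguments` is hypothetical, and plain String keys stand in for FieldName so the sketch runs without JPMML on the classpath. With JPMML-Evaluator 1.5.x, the declared names can be read from `model.getInputFields()`, and the library's documented approach passes each raw value through `InputField.prepare(...)` before putting it into the map.

```java
import java.util.*;

// Hypothetical helper: turns a raw record into the name -> value map that
// Evaluator.evaluate expects. String keys stand in for org.dmg.pmml.FieldName.
class ArgumentsSketch {

    static Map<String, Object> buildArguments(List<String> inputNames, Map<String, String> rawRecord) {
        Map<String, Object> arguments = new LinkedHashMap<>();
        List<String> missing = new ArrayList<>();
        for (String name : inputNames) {
            String raw = rawRecord.get(name);
            if (raw == null || raw.isBlank()) {
                missing.add(name);                            // collect every missing field
            } else {
                arguments.put(name, Double.valueOf(raw.trim())); // numeric input expected
            }
        }
        if (!missing.isEmpty()) {
            throw new IllegalArgumentException("Missing input fields: " + missing);
        }
        return arguments;                                      // extra columns are ignored
    }

    public static void main(String[] args) {
        List<String> inputNames = List.of("Aggression", "Stress", "Self-Regulation");
        Map<String, String> row = Map.of("Aggression", "35.6058", "Stress", "24.193",
                                         "Self-Regulation", "77.715", "Comment", "ignored");
        System.out.println(buildArguments(inputNames, row));
    }
}
```

Failing fast is a design choice made for this sketch: a PMML document can declare its own missing-value treatment for each field, so you may prefer to pass missing values through and let the evaluator handle them.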