机器学习-java调用机器学习pai的LR模型

背景

最近在使用阿里的机器学习pai进行模型训练,训练出来的模型需要给java进行调用来使用。本博客阐述java调用pmml进行预测的过程。

实战

获取pmml模型

在机器学习pai->模板实验->心脏病预测,可以直接训练并且导出模型,具体步骤参考:https://help.aliyun.com/document_detail/34929.html?spm=a2c4g.11186623.6.676.5394607a96UDkz

java调用pmml

依赖支持库jpmml:https://github.com/jpmml/jpmml-evaluator.git

<!-- https://mvnrepository.com/artifact/org.jpmml/pmml-evaluator -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.3.5</version>
        </dependency>

这里使用的是1.3.5,跟阿里的版本保持一致,避免出现版本兼容性的问题。

下载pmml的模型文件:https://download.csdn.net/download/wangjie5540/12101924

java源码

@Log4j2
public class PmmlPredict {

    public static Evaluator evaluator;

    public static void initModel() throws IOException, SAXException, JAXBException {
        File file = new File("/home/wang/Downloads/lr_demo.pmml");
        PMML pmml = null;
        try (InputStream is = new FileInputStream(file)) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        }
        evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
        evaluator.verify();
    }

    public static Integer predict(JSONObject feature) {
        // 获取模型定义的特征
        List<? extends InputField> inputFields = evaluator.getInputFields();
        log.info("模型的特征是:{}", inputFields);
        // 获取模型定义的目标名称
        List<? extends TargetField> targetFields = evaluator.getTargetFields();
        log.info("目标字段是:{}", targetFields);


        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputName = inputField.getName();
            String name = inputName.getValue();
            Object rawValue = feature.get(name);
            FieldValue inputValue = inputField.prepare(rawValue);
            arguments.put(inputName, inputValue);
        }
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        for (Map.Entry<FieldName, ?> entry : results.entrySet()) {
            log.info(entry.getKey());
            log.info(entry.getValue());
        }
        Map resultRecord = (Map) EvaluatorUtil.decode(results);
        Object res = resultRecord.get(FieldName.create("y"));
        log.info("预测结果:");
        log.info(results);
        log.info(resultRecord);
        log.info(res);
        log.info(evaluator.getActiveFields());
        log.info(evaluator.getTargetFields());
        log.info(evaluator.getSummary());
        log.info(evaluator.getMiningFunction());
        return null;
    }

    public static void main(String[] args) throws JAXBException, SAXException, IOException {
        PmmlPredict.initModel();
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("sex", 1);
        jsonObject.put("address", 1);
        jsonObject.put("famsize", 1);
        jsonObject.put("pstatus", 0);
        jsonObject.put("medu", 0);
        jsonObject.put("fedu", 1);
        jsonObject.put("mjob", 1);
        jsonObject.put("fjob", 0);
        jsonObject.put("guardian", 1);
        jsonObject.put("traveltime", 0);
        jsonObject.put("studytime", 0.3);
        jsonObject.put("failures", 0.3);
        jsonObject.put("schoolsup", 0);
        jsonObject.put("fumsup", 1);
        jsonObject.put("paid", 0);
        jsonObject.put("activities", 0);
        jsonObject.put("higher", 0);
        jsonObject.put("internet", 1);
        jsonObject.put("famrel", 0);
        jsonObject.put("freetime", 0.7);
        jsonObject.put("goout", 0.5);
        jsonObject.put("dalc", 0.7);
        jsonObject.put("walc", 0);
        jsonObject.put("health", 0.5);
        jsonObject.put("absences", 0.8);
//        1,1,0,0,1,1,0,1,0,0.3333333333333333,0.3333333333333333,0,1,0,0,0,1,0,0.75,0.5,0.75,0,0,0.5,0.08,0
        PmmlPredict.predict(jsonObject);
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值