java使用pmml调用sklearn算法模型

2 篇文章 0 订阅
1 篇文章 0 订阅

场景

需要在java调用python sklearn训练评估的模型,本文介绍使用pmml来实现。

生成pmml文件

#引入sklearn2pmml包
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

#使用PMMLPipeline包裹具体评估器
clf = PMMLPipeline([("MLPClassifier", MLPClassifier(hidden_layer_sizes=(25,), random_state=1, max_iter=100, warm_start=True))])
clf.fit(value, label)

#保存模型到指定文件
sklearn2pmml(clf, "MLPClassifier.pmml", with_repr=True)

JAVA调用模型

引用java maven依赖包

        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.5.15</version>
        </dependency>

        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.5.15</version>
        </dependency>

java加载模型并评估

        Map<String, Object> paramData = new HashMap<>();
        paramData.put("x1", 180D);
        paramData.put("x2", 350D);

        FileInputStream inputStream = new FileInputStream("MLPClassifier.pmml");
        //解析pmml文件,实际上是用JAXB做xml的解析
        PMML pmml = PMMLUtil.unmarshal(inputStream);
        //生成评估器
        ModelEvaluator<?> evaluate = new ModelEvaluatorBuilder(pmml).build();

        //构建输入参数
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        List<InputField> inputFields = evaluate.getInputFields();
        for (InputField inputField : inputFields) {            //将参数通过模型对应的名称进行添加
            FieldName inputFieldName = inputField.getName();   //获取模型中的参数名
            Object paramValue = paramData.get(inputFieldName.getValue());   //获取模型参数名对应的参数值
            FieldValue fieldValue = inputField.prepare(paramValue);   //将参数值填入模型中的参数中
            arguments.put(inputFieldName, fieldValue);          //存放在map列表中
        }

        //开始评估
        Map<FieldName, ?> target = evaluate.evaluate(arguments);

        //获取评估结果
        List<TargetField> targetFields = evaluate.getTargetFields();
        Object targetFieldValue = target.get(targetFields.get(0).getFieldName());
        System.out.println("targetFieldValue: " + targetFieldValue);
        System.out.println("target: " + target);

注意事项

1.注意生成模型的版本和java依赖包的版本要匹配,否则java侧会无法解析该pmml模型

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值