jav使用python模型实现方案

大体流程

用python将生成的模型导出为pmml格式,java通过加载本地pmml文件的形式读取模型,给予参数,实现预测或分类,本文用spring boot,接口接收前端各参数数值,预测后返回给前端使用,大致这么个流程

python方面

生成模型并导出pmml的代码

def exportPmml():
    iris = load_iris()
    # 创建带有特征名称的 DataFrame
    iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
    # 创建模型管道
    iris_pipeline = PMMLPipeline([
        ("classifier", RandomForestClassifier())
    ])
    # 训练模型
    iris_pipeline.fit(iris_df, iris.target)
    # 导出模型到 RandomForestClassifier_Iris.pmml 文件
    sklearn2pmml(iris_pipeline, "RandomForestClassifier_Iris.pmml")
    return None

网上可以搜到挺多,这里是用的JPMML导出Pmml,pipeline用的是PMMLPipeline,其实也是实现的sklearn的pipeline,如果用Nyoka,pipeline用的是sklearn的pipeline,差不多,它们的github主页都提供了Usage和多种模型案例

Java方面

一个坑

网上搜到的代码在加载pmml模型是如果遇到报错提示Exception in thread "main" java.lang.IllegalArgumentException: http://www.dmg.org/PMML-4_4,尝试把文件里第二行版本号修改为修改版本号

另一个小点

maven里有关pmml的包的版本号1.4.x与1.5.x的代码有差别,导的版本不匹配,可能导的东西都找不到,而且名字也会不一样

代码

@Component
public class LoadModel1_5_x {
    private Evaluator evaluator;
    public String modelUrl="E:\\Softwares\\Learn\\IntelliJ_Projects\\IrisModelTest\\src\\main\\resources\\static\\RandomForestClassifier_Iris.pmml";
    
    public LoadModel1_5_x() {
        PMML pmml;
        // 模型导入
        try {
            if (modelUrl != null) {
                File file = new File(modelUrl);
                InputStream inputStream = new FileInputStream(file);

                InputStream is = inputStream;
                pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
                ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
                //1.5.x版本 使用一个pmml配置了两个不同的evaluators
                ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
                //第一种
                this.evaluator = modelEvaluatorBuilder.build();
                System.out.println("模型build成功!");
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (JAXBException e) {
            e.printStackTrace();
        }
    }

    public Object predict(Map<String,Double> irismap){
        List<InputField> inputFields = evaluator.getInputFields();
        // 过模型的原始特征,从画像中获取数据,作为模型输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = irismap.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();
        //对于分类问题等有多个输出。
        for (TargetField targetField : targetFields) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.err.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }
        return results;
    }
  }

然后controller里autowired一个模型,构造一个参数map

Map<String, Double> map = new HashMap<String, Double>();
        map.put("sepal length (cm)", (double) 7);
        map.put("sepal width (cm)", (double) 3.2);
        map.put("petal length (cm)", (double) 4.7);
        map.put("petal width (cm)", (double) 1.4);

使用模型类的predict方法就可以简单实现啦

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值