Java调用pyhton训练的机器学习模型

python广泛用于机器学习训练模型,java又被大量开发者所使用的,因此存在跨语言调用的问题。幸好有pmml的出现,将python模型直接保存为“.pmml”结尾的文件使其调用。最简单的流程基本走通。存在的问题:如何将PMML模型文件用于AS中,使模型能够用于App的使用,目前还在寻找,太难了。。。。。

python模型直接保存为pmml: 使用的为sklearn2pmml

from sklearn.ensemble import RandomForestRegressor
from sklearn import model_selection,metrics,cross_validation
import numpy as np
import xlrd
import csv
import pandas as pd
from openpyxl.workbook import Workbook
from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
import time
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

#读取excle
def readexcle(filename):
    fh=xlrd.open_workbook(filename)
    table=fh.sheets()[0]
    rows=table.nrows
    exdata=[]
    for row in range(rows):          #获取行数
        data=table.row_values(row) #读取每行的数值
        exdata.append(data)
    return exdata

#分离数据以及label
def splitnumandlab(data):
    culm=len(data[0][:])
    row=len(data)
    
    
    yslabel=[]
    ysnumdata=[]
    for j in range(row):
        yslabel.append(data[j][culm-1])
        
        ysnumdata.append(data[j][0:culm-12])
        print(len(data[j][0:culm-12]))
    return yslabel,ysnumdata
#获取excle原始数据
y_s_data=readexcle(r"C:\Users\Rui Kong\Desktop\ceshi.xlsx")

label1,numdata1=splitnumandlab(y_s_data) 
x_num=np.array(numdata1)
y_lab=np.array(label1)
x_train,x_test,y_train,y_test=model_selection.train_test_split(x_num,y_lab,train_size=0.85)
clf=RandomForestRegressor(n_estimators=20,max_depth=200)
pipline=PMMLPipeline([("classifier",clf)])
pipline.fit(x_train,y_train)
#pipline.score(x_train,y_train)
#pipline.score(x_test,y_test)
sklearn2pmml(pipline,r"C:\Users\Rui Kong\Desktop\ceshi.pmml")

Eclipse中调用,需要导入的Jar包为:pmml-evaluator-example-executable-1.4.13.jar

package pmml;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.xml.bind.JAXBException;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.xml.sax.SAXException;

public class main_enter {
	public static void main(String args[]) {
		Evaluator evaluator = loadPmml();
		ArrayList<Float> arraylist = new ArrayList<>();
		arraylist.add(10.1f);
		arraylist.add(158.1f);
		arraylist.add(1009.1f);
		arraylist.add(1800.1f);
		arraylist.add(158.1f);
		Object ab = predict(evaluator,arraylist);
		System.out.println(ab);
	}
	public static  Evaluator loadPmml(){
        PMML pmml  = new PMML();
        InputStream inputStream = null;
        File file = new File("C:\\Users\\Rui Kong\\Desktop\\ceshi.pmml");
        try{
            inputStream = new FileInputStream(file);//在非activity类中使用getResource需要传入context
           
        }catch (Exception e){
            e.printStackTrace();
        }
        if(inputStream==null){
            return null;
        }
        try{
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (JAXBException e1) {
            e1.printStackTrace();
        } catch (SAXException e2) {
            e2.printStackTrace();
        }
        try{
            inputStream.close();
        }catch (IOException e){
            e.printStackTrace();
            }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
        return evaluator;
    }
    public static Object predict(Evaluator evaluator, ArrayList<Float>a){
        int m=a.size();
        HashMap<String, Float>map = new HashMap<>();
        //将数组储存在map中
        for (int i =1;i<m+1;i++){
            map.put("x"+i,a.get(i-1));
        }
        System.out.println(map);
        List<InputField> inputFields = evaluator.getInputFields();
        //从画像中获取数据,作为模型的输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField :inputFields){
            FieldName inputFieldName = inputField.getName();
            Object rawValue = map.get(inputFieldName.getValue());
            FieldValue inputValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName,inputValue);
        }
        System.out.println(arguments);
        Map<FieldName,?> results =evaluator.evaluate(arguments);//模型识别结果文件
        System.out.println(results);
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);//返回第一个数作为预测结果,预测只有一个结果,分类则有几个结果
        Object targetValue = results.get(targetField);

        return targetValue;
    }
}

反正最后能跑通,给出了一个预测结果:
结果示意图
[github上面,evaluator的jar包]https://github.com/jpmml/jpmml-evaluator/releases,也可以好好看看大佬写的example

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值