java调用python模型
PMML格式
使用Java自带的 Runtime.getRuntime().exec(args) 方法直接调用python脚本
PMML格式
1、首先将python代码训练的模型保存为pmml格式,代码如下
# Export an XGBoost classifier to a PMML file so it can be scored outside Python.
# `xgb`, `X_train` and `y_train` must already be defined before this snippet runs.
from sklearn2pmml import PMMLPipeline, sklearn2pmml

model = xgb.XGBClassifier()
# Wrap the estimator in a single-step pipeline named "classifier" and train it.
pipeline = PMMLPipeline([("classifier", model)])
pipeline.fit(X_train, y_train)
# Write the fitted pipeline to "xgb.pmml".
sklearn2pmml(pipeline, "xgb.pmml", with_repr=True)
2、然后使用Java读取pmml文件对数据进行预测,代码如下。
(后来改为使用Java启动Python解释器进程的方式直接运行python脚本,见下文。)
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;
/**
 * Loads a PMML model and evaluates it against caller-supplied input fields.
 *
 * <p>The model can be read either from a classpath resource or from an already
 * opened input stream. In both cases the constructor closes the stream and
 * verifies the resulting evaluator before returning.
 *
 * @author liaotuo
 */
public class ModelInvoker {
    // Raw type kept exactly as originally declared.
    // NOTE(review): parameterize once the jpmml-evaluator version in use is confirmed.
    private final ModelEvaluator modelEvaluator;

    /**
     * Loads the model from a classpath resource.
     *
     * @param pmmlFileName resource name, resolved through this class's class loader
     * @throws NullPointerException if {@code pmmlFileName} is null
     * @throws IllegalArgumentException if no such resource exists on the classpath
     * @throws IllegalStateException if the resource cannot be parsed as PMML
     */
    public ModelInvoker(String pmmlFileName) {
        if (pmmlFileName == null) {
            throw new NullPointerException("pmmlFileName must not be null");
        }
        InputStream is = ModelInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName);
        if (is == null) {
            // getResourceAsStream signals "not found" with null rather than an exception.
            throw new IllegalArgumentException("PMML resource not found on classpath: " + pmmlFileName);
        }
        this.modelEvaluator = loadEvaluator(is);
        this.modelEvaluator.verify();
        System.out.println("模型读取成功");
    }

    /**
     * Loads the model from an input stream. The stream is always closed by this
     * constructor, whether or not loading succeeds.
     *
     * @param is stream positioned at the start of a PMML document
     * @throws NullPointerException if {@code is} is null
     * @throws IllegalStateException if the stream cannot be parsed as PMML
     */
    public ModelInvoker(InputStream is) {
        if (is == null) {
            throw new NullPointerException("PMML input stream must not be null");
        }
        this.modelEvaluator = loadEvaluator(is);
        this.modelEvaluator.verify();
    }

    /**
     * Evaluates the model for one input record.
     *
     * @param paramsMap input values keyed by field name
     * @return the evaluator's result map, keyed by field name
     */
    public Map<FieldName, ?> invoke(Map<FieldName, Object> paramsMap) {
        return this.modelEvaluator.evaluate(paramsMap);
    }

    /** Parses the PMML document from {@code is}, builds an evaluator for it, and closes the stream. */
    private static ModelEvaluator loadEvaluator(InputStream is) {
        try {
            PMML pmml = PMMLUtil.unmarshal(is);
            return ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
        } catch (SAXException | JAXBException e) {
            // Surface the real parse failure instead of a later, unrelated NullPointerException.
            throw new IllegalStateException("Failed to parse PMML model", e);
        } finally {
            closeQuietly(is);
        }
    }

    /** Closes {@code is} once, ignoring a failure to close. */
    private static void closeQuietly(InputStream is) {
        try {
            is.close();
        } catch (IOException ignored) {
            // The document has already been read (or parsing already failed);
            // a close failure on an input stream carries no additional information.
        }
    }
}
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.FieldName;
/**
 * Demo driver: loads the PMML model at {@link #pmmlPath}, reads input rows from
 * a space-separated text file and prints the model output for every row.
 *
 * @author gs
 */
public class ModelCalc {
    static String pmmlPath = "E:\\workspace\\python\\tydic\\model\\xgb.pmml";

    public static void main(String[] args) throws IOException {
        String modelArgsFilePath = "E:\\workspace\\python\\tydic\\model\\test\\X_val";
        predictFromFile(modelArgsFilePath);
    }

    /**
     * Reads input rows from a file and evaluates the model on each of them,
     * printing every output field value to standard output.
     *
     * @param modelArgsFilePath path of the space-separated input file (first line is the header)
     * @return the list of collected predictions; nothing is currently added to it,
     *         so it is always empty
     * @throws FileNotFoundException if the model file or the input file does not exist
     * @throws IOException if reading either file fails
     */
    public static List<String> predictFromFile(String modelArgsFilePath) throws FileNotFoundException, IOException {
        ModelInvoker invoker;
        // ModelInvoker closes the stream itself; closing it again here is harmless
        // and guarantees the file handle is released on every path.
        try (InputStream modelStream = new BufferedInputStream(new FileInputStream(pmmlPath))) {
            invoker = new ModelInvoker(modelStream);
        }

        List<Map<FieldName, Object>> paramList = getDataFromFile(modelArgsFilePath);
        List<String> predictResult = new ArrayList<>();
        int lineNum = 0; // 1-based index of the row being processed
        for (Map<FieldName, Object> param : paramList) {
            lineNum++;
            System.out.println("======当前行: " + lineNum + "=======");
            Map<FieldName, ?> result = invoker.invoke(param);
            // Print every output field of the result, in the map's iteration order.
            for (Object value : result.values()) {
                System.out.println(value);
            }
            // NOTE(review): an earlier revision collected every third output value into
            // predictResult (presumably the probability of class 1 - confirm against the
            // model's output fields). That code was disabled, so the list stays empty.
        }
        return predictResult;
    }

    /**
     * Reads the input file into one map per data row.
     *
     * <p>The first line holds the space-separated column names; every following
     * line holds the space-separated values of one row. A row may have fewer
     * values than the header, in which case only the leading columns are set.
     *
     * @param filePath path of the input file
     * @return one map per data row; empty if the file has no lines at all
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a row has more values than the header has names
     */
    private static List<Map<FieldName, Object>> getDataFromFile(String filePath) throws IOException {
        List<Map<FieldName, Object>> rows = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
            String header = br.readLine();
            if (header == null) {
                // Completely empty file: no header, hence no rows.
                return rows;
            }
            String[] names = header.split(" ");
            // Build the field names once instead of once per cell.
            FieldName[] fields = new FieldName[names.length];
            for (int i = 0; i < names.length; i++) {
                fields[i] = new FieldName(names[i]);
            }

            String paramLine;
            int fileLine = 1; // the header was line 1
            while ((paramLine = br.readLine()) != null) {
                fileLine++;
                String[] values = paramLine.split(" ");
                if (values.length > fields.length) {
                    throw new IllegalArgumentException("Line " + fileLine + " of " + filePath + " has "
                            + values.length + " values but the header declares only " + fields.length);
                }
                Map<FieldName, Object> row = new HashMap<>();
                for (int i = 0; i < values.length; i++) {
                    row.put(fields[i], values[i]);
                }
                rows.add(row);
            }
        }
        return rows;
    }
}
使用Runtime.getRuntime().exec(args)
这种方式主要是编写Java代码,通过启动Python进程来执行脚本,示例如下:
/**
 * Demo that launches a Python script as a child process and echoes the
 * script's standard output and standard error to this program's standard output.
 */
public class PythonDemo {
    public static void main(String[] args) {
        try {
            // Positional arguments passed to the script.
            String host = "localhost";
            String port = "3306";
            String user = "root";
            String passwd = "123456";
            String path = "C:/";
            String database = "dic_coll_consume";
            String start_date = "2017-08-01";
            String end_date = "2017-09-01";
            String[] command = new String[] { "python", "C:\\model_train.py", host, port, user, passwd, path,
                    database, start_date, end_date };
            Process pr = Runtime.getRuntime().exec(command);
            // Drain both pipes concurrently so that neither can fill up and block the child.
            print(pr.getInputStream());
            print(pr.getErrorStream());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Starts a thread that copies {@code stream} line by line to standard
     * output and prints {@code end} once the stream is exhausted.
     *
     * <p>The bytes are decoded as UTF-8 directly. The previous approach decoded
     * with the platform charset and then re-encoded and re-decoded each line,
     * which loses characters whenever the platform charset is not UTF-8.
     *
     * @param stream the child process stream to echo; closed when fully read
     */
    private static void print(final InputStream stream) {
        new Thread(new Runnable() {
            @Override
            public void run() {
                try (BufferedReader in = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
                    String line;
                    while ((line = in.readLine()) != null) {
                        System.out.println(line);
                    }
                    System.out.println("end");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }).start();
    }
}
---------------------
作者:Dreamcatcher5
来源:CSDN
原文:https://blog.csdn.net/qq_31955775/article/details/78329699
版权声明:本文为博主原创文章,转载请附上博文链接!