Python 代码(Flask 服务端):
from tensorflow.keras import models
import numpy as np
import pickle
import pandas as pd
from flask import Flask
import time;
# Flask application object; the @app.route handler in this file is registered on it.
app = Flask(__name__)
def p_color(df):
    """Return the 'color' entry of ``df``, or '_' when that entry is null.

    ``df`` is anything indexable by the key 'color' (for example a single
    DataFrame row passed in by ``DataFrame.apply(..., axis='columns')``).
    """
    value = df['color']
    return '_' if pd.isnull(value) else value
def p_clarity(df):
    """Return the 'clarity' entry of ``df``, or '_' when that entry is null.

    ``df`` is anything indexable by the key 'clarity' (for example a single
    DataFrame row passed in by ``DataFrame.apply(..., axis='columns')``).
    """
    value = df['clarity']
    return '_' if pd.isnull(value) else value
def split_features(X, n_features=9):
    """Split the last axis of ``X`` into a list of single-column arrays.

    Each element is ``X[..., [i]]``, so the last axis is kept with length 1
    (an ``(n, 9)`` input yields nine ``(n, 1)`` arrays).

    Args:
        X: NumPy array whose last axis holds the feature columns. With the
            default of nine features the column order is carat, color,
            clarity, depth, table, price, x, y, z.
        n_features: Number of leading columns to split off. Defaults to 9,
            which matches the previously hard-coded behaviour.

    Returns:
        list: ``n_features`` arrays, one per column, in column order.

    Raises:
        IndexError: If ``X`` has fewer than ``n_features`` columns.
    """
    # The list-index form X[..., [i]] (rather than X[..., i]) keeps the
    # trailing axis, which is what makes every piece 2-D for an (n, k) input.
    return [X[..., [i]] for i in range(n_features)]
# Load the saved Keras model once, at module import, so the request handler
# reuses the same object instead of reloading the .h5 file per request.
model = models.load_model('C:\\Users\\张家豪\\Desktop\\keras\\sf.h5')
@app.route("/predict1/<data>")
def predict_res(data):
    """Time one load-preprocess-predict pass and return the elapsed seconds.

    The handler reads the diamonds CSV and the pickled category dictionaries
    from disk, encodes the categorical columns, runs ``model.predict`` on the
    result, and reports how long all of that took.

    Args:
        data: Path segment captured from the URL. NOTE(review): it is not
            used; the features come from the CSV file instead. Confirm
            whether the request payload should be parsed here.

    Returns:
        str: Elapsed time in seconds between the start of the handler and
        the end of the prediction, formatted with ``str``. The predicted
        classes themselves are not returned.
    """
    # time.clock() was removed in Python 3.8; perf_counter() is its replacement.
    start = time.perf_counter()
    df = pd.read_csv(r'C:\Users\张家豪\Desktop\keras\diamonds4.csv')
    df = df[['carat','cut','color','clarity','depth','table','price','x','y','z']]
    collist = ['color','clarity','cut']
    for col in collist:
        df[col] = df[col].astype(str)
    # The context manager closes the file even if unpickling raises.
    # NOTE(review): pickle.load can execute arbitrary code; keep this file trusted.
    with open(r'C:\Users\张家豪\Desktop\keras\dicts.pkl','rb') as f_open:
        dict_f = pickle.load(f_open)
    for col in collist:  # translate labels through the loaded dictionaries
        df[col] = df[col].map(dict_f[col])
    # Labels absent from a dictionary map to NaN; p_color / p_clarity turn
    # those NaNs into the placeholder '_' ...
    df['color'] = df.apply(p_color,axis='columns')
    df['clarity'] = df.apply(p_clarity,axis='columns')
    # ... which is then swapped for the dictionary's own '_' entry. Assigning
    # the result back (instead of Series.replace(inplace=True) on an attribute
    # access) keeps the update effective under pandas copy-on-write.
    df['color'] = df['color'].replace('_',dict_f['color']['_'])
    df['clarity'] = df['clarity'].replace('_',dict_f['clarity']['_'])
    x = df[['carat','color','clarity','depth','table','price','x','y','z']]
    x = np.array(x)
    x = split_features(x)
    # The prediction is computed only so that it is included in the timing;
    # its result is discarded for now.
    model.predict(x).argmax(1)
    end = time.perf_counter()
    return str(end - start)
# Start Flask's built-in server only when this file is executed as a script,
# so that importing the module (from a WSGI server, a test, or the `flask`
# CLI) does not block on a running server.
if __name__ == "__main__":
    app.run()
Java 代码(HTTP 客户端):
package com.shangguigu.helloworld;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URL;
import java.net.URLConnection;
import java.util.List;
import java.util.Map;
/**
 * Minimal HTTP client that calls the Flask prediction endpoint with a GET
 * request and prints whatever the server returns.
 */
public class PythonFlask {
    /**
     * Sends an HTTP GET to {@code strUrl + "/" + requestParams} and returns
     * the response body.
     *
     * <p>The body is read line by line and the lines are concatenated without
     * their line terminators. Any exception is printed to stderr and the text
     * read so far (possibly the empty string) is returned instead of being
     * propagated.
     *
     * @param strUrl        base URL of the endpoint, without a trailing slash
     * @param requestParams value appended to the URL as the last path segment
     * @return the response body with line terminators removed, or the text
     *         read before a failure occurred
     */
    public static String sendGet(String strUrl, String requestParams) {
        // StringBuilder avoids re-copying the whole response on every line.
        StringBuilder responseParams = new StringBuilder();
        try {
            URL url = new URL(strUrl + "/" + requestParams);
            URLConnection urlConnection = url.openConnection(); // open the connection to the URL
            // Generic request headers.
            urlConnection.setRequestProperty("accept", "*/*");
            urlConnection.setRequestProperty("connection", "Keep-Alive");
            urlConnection.setRequestProperty("user-agent",
                    "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1;SV1)");
            urlConnection.connect(); // establish the connection
            // try-with-resources closes the reader on every path. The body is
            // decoded explicitly as UTF-8 (Flask's default response charset)
            // rather than with the platform default encoding.
            try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
                    urlConnection.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
                String strLine;
                while ((strLine = bufferedReader.readLine()) != null) {
                    responseParams.append(strLine);
                }
            }
        } catch (Exception e) {
            // Best-effort client: report the failure and return what was read.
            e.printStackTrace();
        }
        return responseParams.toString();
    }

    /**
     * Calls the local Flask endpoint with one sample feature row and prints
     * the response.
     */
    public static void main(String[] args) {
        // carat, color, clarity, depth, table, price, x, y, z
        String param = "0.65,5,1,62.0,58.0,1838,5.50,5.58,3.43";
        String res = sendGet("http://localhost:5000/predict1", param);
        System.out.println(res);
    }
}
先启动 Python 服务;控制台出现下图所示的启动信息后,说明服务已就绪,Java 客户端可以访问。
注意:目前接口返回的是预测耗时(秒),而不是预测结果;实际使用时记得把返回值改为预测结果。
Python 控制台出现下图所示的请求日志,说明 Java 客户端已经成功访问到接口。