基于MySQL训练某机器学习算法模型与Java跨语言

 简介:通过将存储于MySQL中的数据提取出来,调用机器学习中的某个算法训练出模型并保存,再使用Java跨语言调用训练好的模型,然后可视化。训练流程如图1.1中所示。主要涉及技术及工具有Sprint Boot2 + Vue2 + Echarts + Anaconda + IDEA + Noetpad++等;需要准备:

1、环境准备。 Java + Maven + MySQL + nvm + Anaconda

2、数据准备。 该项目中使用地震数据。

图1.1 

1、模型训练

1.1 数据获取及导入MySQL

git clone git@gitee.com:luoyanUFO/my-sql-data.git

1.2 启动Anaconda (安装省略)

在桌面新建任意一个文件夹,点击进入,输入cmd.如图1.2、1.3、1.4所示,然后输入命令(如果没有下载先下载jupyter notebook),一段时间后自动跳转浏览器,如果没有自动跳转复制url在浏览器中打开。

注意:CTRL + enter 运行代码(光标要停在要运行的那个框中); shift + enter 新建一个运行框

jupyter notebook

 图 1.2   

图 1.3 

图 1.4 

1.3 安装pymysql

如果两个命令都安装不成功,再去网上找一下其他方法。如果不行可以去官网下载对应的轮子,然后导入到本地。过程省略。

# 安装, 请勿重复运行这一栏
# conda install pymysql # 先用这个,不行然后注释用下一个
# !pip install pymysql
#  pip install --user -i https://pypi.tuna.tsinghua.edu.cn/simple sklearn2pmml  # 安装sklearn2pmml, 如果后面导入出错再安装

1.4 提取MySQL中的数据

'''
 随机森林预测
    --------  通过经纬度训练模型-------------------
'''


#  导入所需包
import pymysql
import pandas as pd 
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from sklearn2pmml import PMMLPipeline, sklearn2pmml


#连接数据库
url = '192.168.6.166' # 连接地址
user = 'root' # 用户名
pws = '123456' # 密码
db = 'rwly' # 数据库名称
conn=pymysql.connect(host = url # 连接名称,默认127.0.0.1
,user = user       # 用户名
,passwd = pws     # 密码
,port = 3306          # 端口,默认为3306
,db = db           # 数据库名称
,charset='utf8'      # 字符编码
)
cur = conn.cursor()   # 生成游标对象
sql="SELECT `longitude`,`dim`,`km`,`ml` FROM ads_serism_all;"    # SQL语句-- 注意查询表
cur.execute(sql)    # 执行SQL语句
data = cur.fetchall()   # 通过fetchall方法获得数据

# for i in data[:21]:     # 打印输出前2条数据
#     print (i)
df = pd.DataFrame(data, columns=['经度', '维度', '震深', '震级'])
print("-----------------------------------查询数据展示------------------------------------------------------------\n")
print(df)
# print(classification_report(data, target_names=target_names))

cur.close()   # 关闭游标
conn.close()  # 关闭连接

运行结果如图1.5。

图1.5                                

 1.5 训练模型

        1.5.1 震级模型

'''
 随机森林预测
    -----------------------------------随机森林预测 -- 震级模型-----------------------------------------------
'''

# 特征提取:选取与薪资相关的特征
X = df[['经度', '维度']]  # 用身高和年龄来预测得分情况
y = df['震级']
#
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 随机森林模型建立
rf = RandomForestRegressor(n_estimators=100, random_state=123)
#
# 模型训练
rf.fit(X_train, y_train)
# 模型评估
y_pred = rf.predict(X_test)
print("------------------------震级模型评估-----------------------------------\n")
# 平均绝对误差
mae = mean_absolute_error(y_test, y_pred)
print('震级模型评估',y_pred)
print("--------------------------震级平均绝对误差(MAE)---------------------------------\n")
print('震级平均绝对误差(MAE):', mae)

# 导出模型到 ml.pmml 文件
sklearn2pmml(rf, "ml.pmml", with_repr = True)

1.5.2 震深模型

'''
 随机森林预测
    -----------------------------------随机森林预测 -- 震深模型-----------------------------------------------
'''
#
# 特征提取:选取与薪资相关的特征
X = df[['经度', '维度']]  # 用身高和年龄来预测得分情况
y = df['震深']
#
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 随机森林模型建立
rf2 = RandomForestRegressor(n_estimators=100, random_state=123)
#
# 模型训练
rf2.fit(X_train, y_train)
# 模型评估
y_pred = rf2.predict(X_test)
print("------------------------震深模型评估-----------------------------------\n")
# 平均绝对误差
mae = mean_absolute_error(y_test, y_pred)
print('震深模型评估',y_pred)
print("--------------------------震深平均绝对误差(MAE)---------------------------------\n")
print('震深平均绝对误差(MAE):', mae)

# 导出模型到 km.pmml 文件
sklearn2pmml(rf2, "km.pmml", with_repr = True)

1.6 模型预览

模型保存在新建的那个文件夹中,如图1.6所示;使用idea打开,可以修改。

图1.6   

2、调用模型   

创建 spring boot 工程省略  。。。                                                          

 2.1 调入jar包

<!--    pmml    -->
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.1</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.4.1</version>
        </dependency>

2.2 模型位置

将模型存放到static目录下,如图1.7所示

图 1.7

2.3 调用方法

在util目录下编写方法

package com.itly.rwly.utils;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

/**
 * Created with IntelliJ IDEA.
 *
 * @Author: 程序员CK君
 * @Version: 1.0
 * @Date: 2024-03-17-20:53
 * @Description: 模型调用工具类
 */
@SuppressWarnings({"all"})
public class UserDefinedModel {

    private Evaluator modelEvaluator;

    /**
     * 通过传入 PMML 文件路径来生成机器学习模型
     *
     * @param pmmlFileName pmml 文件路径
     */
    public UserDefinedModel(String pmmlFileName) {
        PMML pmml = null;

        try {
            if (pmmlFileName != null) {
                InputStream is = new FileInputStream(pmmlFileName);
                pmml = PMMLUtil.unmarshal(is);
                try {
                    is.close();
                } catch (IOException e) {
                    System.out.println("InputStream close error!");
                }

                ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

                this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
                modelEvaluator.verify();
                System.out.println("加载模型成功!");
            }
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (JAXBException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }

    }
    // 获取模型需要的特征名称
    public List<String> getFeatureNames() {
        List<String> featureNames = new ArrayList<String>();

        List<InputField> inputFields = modelEvaluator.getInputFields();

        for (InputField inputField : inputFields) {
            featureNames.add(inputField.getName().toString());
        }
        return featureNames;
    }

    // 获取目标字段名称
    public String getTargetName() {
        return modelEvaluator.getTargetFields().get(0).getName().toString();
    }
    //
    public String predict(String a, String b) {
        Map<String, String> data = new HashMap<>();
        data.put("经度", a);
        data.put("维度", b);
        List<InputField> inputFields = modelEvaluator.getInputFields();
        // 过模型的原始特征,参数获取,作为模型输入
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = modelEvaluator.evaluate(arguments);
        List<TargetField> targetFields = modelEvaluator.getTargetFields();

        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();

        Object targetFieldValue = results.get(targetFieldName);
        System.out.println("target: " + targetFieldName.getValue() + " 预测结果: " + targetFieldValue);
//        int primitiveValue = -1;
//        if (targetFieldValue instanceof Computable) {
//            Computable computable = (Computable) targetFieldValue;
//            primitiveValue = (Integer)computable.getResult();
//        }
//        System.out.println(a + " " + b + ":" + primitiveValue);
        return targetFieldValue + "";
    }
}

2.4 实现调用

在controller目录下编写一个类,如下代码,

package com.itly.rwly.controller;

import com.itly.rwly.pojo.Code;
import com.itly.rwly.pojo.RandomForestParam;
import com.itly.rwly.pojo.Result;
import com.itly.rwly.utils.UserDefinedModel;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Created with IntelliJ IDEA.
 *
 * @Author: 程序员CK君
 * @Version: 1.0
 * @Date: 2024-03-17-22:17
 * @Description: 随机森林预测 --根据经纬度预测地级 和 震深
 */
@Controller
@RequestMapping("/rfs")
@SuppressWarnings("all")
public class RandomForestController {

    @PostMapping("/data")
    @ResponseBody
    public Result getData3(@RequestBody RandomForestParam param) {
        // 震级模型存放目录
        UserDefinedModel clf = new UserDefinedModel("src/main/resources/static/ml.pmml");
        String ml = clf.predict(param.getLongitude(), param.getDim());
        // 震深模型存放目录
        UserDefinedModel clf2 = new UserDefinedModel("src/main/resources/static/km.pmml");
        String km = clf2.predict(param.getLongitude(), param.getDim());
        // 加入数据
        String data = ml + "-" + km;
        return new Result(data != null ? Code.GET_OK : Code.GET_ERR, data);
    }

}

2.5 输入参数验证

使用Postman验证,如图1.8所示.。。。可视化省略

图1.8

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值