建立Logistic回归模型预测吸毒人员的复吸概率,导出模型到Java中调用

  1. 通过sklearn建立Logistic回归模型,并绘制混淆矩阵和ROC曲线

#删选
data=data[['年龄','文化','婚姻','居住','吸毒年限','社交行为得分',"社会支持度得分","是否第一次进入戒毒所","是否复吸"]]
# 查看数据集的基本信息
print(data.info())
print(data.describe())
# 分割数据集,
X=data[['年龄','文化','婚姻','居住','吸毒年限','社交行为得分',"社会支持度得分"]] #特征
X=X.fillna(X.mean()) #用均值代替
y=data["是否复吸"].astype("int")
X_train,X_test=X[0:527],X[528:]
y_train,y_test =y[0:527],y[528:]

# 建模
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
print(log_reg)
log_reg.fit(X_train,y_train)



from sklearn.metrics import accuracy_score
y_predict_log = log_reg.predict(X_test)

# 调用accuracy_score计算分类准确度
acc_score=accuracy_score(y_test,y_predict_log)
print(acc_score)

# #绘制混淆矩阵
from sklearn.metrics import confusion_matrix
cnf_matrix = confusion_matrix(y_test,y_predict_log)
def plot_cnf_matirx(cnf_matrix, description):
    class_names = [0, 1]
    fig, ax = plt.subplots()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names)
    plt.yticks(tick_marks, class_names)
    # create a heat map
    sns.heatmap(pd.DataFrame(cnf_matrix), annot=True, cmap='OrRd',fmt='g')
    ax.xaxis.set_label_position('top')
    plt.tight_layout()
    plt.title(description, y=1.1, fontsize=12)
    plt.ylabel('实际值0/1', fontsize=12)
    plt.xlabel('预测值0/1', fontsize=12)
    plt.show()
plot_cnf_matirx(cnf_matrix, 'Confusion matrix -- Logistic Regression')
decision_scores = log_reg.decision_function(X_test)

# 绘制ROC曲线
from sklearn.metrics import roc_curve
fprs,tprs,thresholds = roc_curve(y_test,decision_scores)
def plot_roc_curve(fprs, tprs):
    plt.figure(figsize=(8, 6), dpi=80)
    plt.plot(fprs, tprs)
    plt.plot([0, 1], linestyle='--')
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.ylabel('TP rate', fontsize=15)
    plt.xlabel('FP rate', fontsize=15)
    plt.title('ROC曲线', fontsize=17)
    plt.show()

plot_roc_curve(fprs, tprs)
# 求面积,相当于求得分
from sklearn.metrics import roc_auc_score  #auc:area under curve
roc_auc_score=roc_auc_score(y_test,decision_scores)
print(roc_auc_score)

正确
建立好的模型正确率为72%左右

roc曲线
绘制出ROC曲线,AUC值为0.85,模型预测效果比较好,有实际运用价值。

  1. 将建立好的模型导出,用的是PMML,用之前需要导入依赖
    建立一个Maven项目,导入依赖:
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>1.8.0</scope>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.5</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.4.14</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>2.0.1</version>
        </dependency>
        <dependency>
            <groupId>org.glassfish.jaxb</groupId>
            <artifactId>jaxb-runtime</artifactId>
            <version>2.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.jayway.jsonpath</groupId>
            <artifactId>json-path</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.45</version>
        </dependency>
    </dependencies>

从Python里导出PMML文件

#删选
data=data[['age','edu','marriage','live','drugyear','score1',"score2","first","relapse"]]
# 查看数据集的基本信息
# 分割数据集,
X=data[['age','edu','marriage','live','drugyear','score1',"score2"]] #特征
X=X.fillna(X.mean()) #用均值代替
y=data["relapse"].astype("int")
X_train,X_test=X[0:527],X[528:]
y_train,y_test =y[0:527],y[528:]

# 建模
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
pipeline=PMMLPipeline([('lr',lr)])
pipeline.fit(X=X_train, y=y_train.values.ravel())
pipeline.verify(X_train.sample(n = 15, random_state=100))

# 导出pmml文件
sklearn2pmml(pipeline, r'logistic.pmml', with_repr=True)



3.为了通过Java调用PMML文件,我们先建立对应的数据库。
基本流程
先定义一个JDBCUtils工具类,方便连接数据库。

public class JDBCUtils {
    private static String url;
    private static String user;
    private static String password;
    private  static String driver;
    /**
     * 文件读取,只会执行一次,使用静态代码块
     */
    static {
        //读取文件,获取值
        try {
            //1.创建Properties集合类
            Properties pro = new Properties();
            //获取src路径下的文件--->ClassLoader类加载器
            ClassLoader classLoader = JDBCUtils.class.getClassLoader();
            URL resource = classLoader.getResource("jdbc.properties");;
            String path = resource.getPath();
            //2.加载文件
            pro.load(new FileReader(path));
            //3获取数据
            url = pro.getProperty("url");
            user = pro.getProperty("user");
            password = pro.getProperty("password");
            driver = pro.getProperty("driver");
            //4.注册驱动
            Class.forName(driver);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
    /**
     * 获取连接
     * @return 连接对象
     */
    public static Connection getConnection() throws SQLException {
        Connection conn = DriverManager.getConnection(url, user, password);
        return conn;
    }
    /**
     * 释放资源
     * @param rs
     * @param st
     * @param conn
     */
    public static void close(ResultSet rs, Statement st, Connection conn){
        if (rs != null){
            try {
                rs.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if(st != null){
            try {
                st.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (conn != null){
            try {
                conn.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }
}

根据所建立的数据库编写一个实体类:

public class user {
    private double age;
    private double edu;
    private double marri;
    private double live;
    private double drugyear;
    private double score1;
    private double score2;

    public user(double age, double edu, double marri, double live, double drugyear, double score1, double score2) {
        this.age = age;
        this.edu = edu;
        this.marri = marri;
        this.live = live;
        this.drugyear = drugyear;
        this.score1 = score1;
        this.score2 = score2;
    }

    @Override
    public String toString() {
        return "user{" +
                "age=" + age +
                ", edu=" + edu +
                ", marri=" + marri +
                ", live=" + live +
                ", drugyear=" + drugyear +
                ", score1=" + score1 +
                ", score2=" + score2 +
                '}';
    }
}


接下来就是关键的代码了,调用PMML文件:
首先把你导出的PMML文件复制到Java项目里,然后导入。

      //导入我们训练好的模型,logistic.pmml
        String pathxml = "logistic.pmml";

通过Map这个数据结构输入到模型里预测。

 public static void predictLrHeart(Map<String, Object> map, String pathxml) throws Exception {

        PMML pmml;
        // 模型导入
        File file = new File(pathxml);
        InputStream inputStream = new FileInputStream(file);
        try (InputStream is = inputStream) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);

            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory
                    .newInstance();
            ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory
                    .newModelEvaluator(pmml);
            Evaluator evaluator = (Evaluator) modelEvaluator;

            List<InputField> inputFields = evaluator.getInputFields();
            // map
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getName();
                Object rawValue = map
                        .get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(rawValue);
                arguments.put(inputFieldName, inputFieldValue);
            }

            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            List<TargetField> targetFields = evaluator.getTargetFields();
            //对于分类问题等有多个输出。
            for (TargetField targetField : targetFields) {
                FieldName targetFieldName = targetField.getName();
                Object targetFieldValue = results.get(targetFieldName);
                System.err.println("预测值输出: " + targetFieldName.getValue()
                        + " 复吸可能性: " + targetFieldValue);
            }
        }
    }

因为上面封装的函数的数据结构是Map,Key是字段,value是字段值,我们必须把从数据库里查出来的结果集也放到一个Map里,写一个封装结果集为Map的函数。

 public static List<Map<String, Object>> getSimpleObjs(String sql, Double... params) {
        Connection conn = null;
        Statement st = null;
        ResultSet rs = null;
        PreparedStatement ps = null;
        List<Map<String, Object>> mapList = new ArrayList<Map<String, Object>>();
        try {
            conn = com.example.util.JDBCUtils.getConnection();
            ps = conn.prepareStatement(sql);
            if (params != null) {
                for (int i = 0; i < params.length; i++) {
                    ps.setObject(i + 1, params[i]);
                }
            }
            rs = ps.executeQuery();
            ResultSetMetaData rsmd = rs.getMetaData();
            //类ResultSet有getMetaData()会返回数据的列和对应的值的信息,
            // 然后我们将列名和对应的值作为map的键值存入map对象之中...
            while (rs.next()) {
                Map<String, Object> map = new HashMap<String, Object>();
                for (int i = 0; i < rsmd.getColumnCount(); i++) {
                    String col_name = rsmd.getColumnName(i + 1);
                    Object col_value = rs.getObject(col_name);
                    if (col_value == null) {
                        col_value = "";
                    }
                    map.put(col_name, (Double) col_value);
                }
                mapList.add(map);
            }
            return mapList;
        } catch (SQLException e) {
            e.printStackTrace();
            return null;
        } finally {
            com.example.util.JDBCUtils.close(rs, ps, conn);
        }


    }

最后,在主函数里执行查询语句,并调用predicHeart的函数进行结果的输出。

 public static void main(String[] args) throws Exception {

        //导入我们训练好的模型,logistic.pmml
        String pathxml = "logistic.pmml";
        //将查询到的所有结果输入到结果集map里
       String sql="SELECT * FROM data";
 
       List<Map<String,Object>> userList=getSimpleObjs(sql);

       for(int i=0;i<userList.size();i++){
           Map<String,Object> usermap=userList.get(i);
           for(String key:usermap.keySet()){
           //输出map里的值
               System.out.println(key+","+usermap.get(key));
           }
           //预测
           predictLrHeart(usermap,pathxml);
       }
    }

结果:可以看到,在我们输入的信息里,第一组值,被划分到了0类,就是不会复吸的人群,后面的1=0.133…这个值就是他会复吸的概率是这个值,0=0.866…说明他不会复吸的概率是86%,因为大于了0.5,就把它归到0类里去,下面一组也是一样。
预测结果

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值