- 通过sklearn建立Logistic回归模型,并绘制混淆矩阵和ROC曲线
# Keep only the columns used in this analysis (features, first-admission flag, target).
data = data[['年龄', '文化', '婚姻', '居住', '吸毒年限', '社交行为得分',
             "社会支持度得分", "是否第一次进入戒毒所", "是否复吸"]]

# Inspect basic information about the data set.
# DataFrame.info() prints its report itself and returns None, so it is not wrapped in print().
data.info()
print(data.describe())

# Build the feature matrix and the target vector.
X = data[['年龄', '文化', '婚姻', '居住', '吸毒年限', '社交行为得分', "社会支持度得分"]]  # features
# Replace missing values with the column mean.
# NOTE(review): the mean is computed over all rows, including the ones later used for testing.
X = X.fillna(X.mean())
y = data["是否复吸"].astype("int")

# Positional hold-out split: the first `train_size` rows train the model, the rest test it.
# The previous slices ([0:527] and [528:]) left row 527 in neither set.
train_size = 527
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# Fit the logistic-regression model.
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
print(log_reg)
log_reg.fit(X_train, y_train)

# Predict on the held-out rows and report the classification accuracy.
from sklearn.metrics import accuracy_score
y_predict_log = log_reg.predict(X_test)
acc_score = accuracy_score(y_test, y_predict_log)
print(acc_score)

# Confusion matrix of true vs. predicted labels (plotted further below).
from sklearn.metrics import confusion_matrix
cnf_matrix = confusion_matrix(y_test, y_predict_log)
def plot_cnf_matirx(cnf_matrix, description):
    """Display a confusion matrix as an annotated heat map.

    Args:
        cnf_matrix: Confusion-matrix values; wrapped in a ``pd.DataFrame``
            before being handed to ``sns.heatmap``.
        description: Text shown as the figure title.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    labels = [0, 1]
    positions = np.arange(len(labels))
    figure, axes = plt.subplots()
    plt.xticks(positions, labels)
    plt.yticks(positions, labels)

    # Draw the heat map on the current axes, annotating each cell with its count.
    matrix_frame = pd.DataFrame(cnf_matrix)
    sns.heatmap(matrix_frame, annot=True, cmap='OrRd', fmt='g')

    # Put the x-axis label above the plot.
    axes.xaxis.set_label_position('top')
    plt.tight_layout()
    axes.set_title(description, y=1.1, fontsize=12)
    axes.set_ylabel('实际值0/1', fontsize=12)
    axes.set_xlabel('预测值0/1', fontsize=12)
    plt.show()
# Draw the confusion matrix for the logistic-regression predictions.
plot_cnf_matirx(cnf_matrix, 'Confusion matrix -- Logistic Regression')
# Scores returned by the fitted model's decision_function for the test rows;
# they are passed to roc_curve below instead of the hard 0/1 predictions.
decision_scores = log_reg.decision_function(X_test)
# Compute the points of the ROC curve.
from sklearn.metrics import roc_curve
fprs,tprs,thresholds = roc_curve(y_test,decision_scores)
def plot_roc_curve(fprs, tprs):
    """Plot an ROC curve together with a dashed diagonal reference line.

    Args:
        fprs: False-positive rates, used as x coordinates.
        tprs: True-positive rates, used as y coordinates.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    tick_size, label_size, title_size = 13, 15, 17

    plt.figure(dpi=80, figsize=(8, 6))
    plt.plot(fprs, tprs)
    # Only y values are given, so x defaults to the indices 0 and 1: the diagonal.
    plt.plot([0, 1], linestyle='--')

    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=tick_size)

    plt.ylabel('TP rate', fontsize=label_size)
    plt.xlabel('FP rate', fontsize=label_size)
    plt.title('ROC曲线', fontsize=title_size)
    plt.show()
plot_roc_curve(fprs, tprs)
# Area under the ROC curve, used as a single summary score for the classifier.
from sklearn.metrics import roc_auc_score  # auc: area under curve
# Keep the result under its own name. Assigning it back to `roc_auc_score`
# would overwrite the imported function with a float and break any later call.
auc_value = roc_auc_score(y_test, decision_scores)
print(auc_value)
建立好的模型正确率为72%左右
绘制出ROC曲线,AUC值为0.85,模型预测效果比较好,有实际运用价值。
- 将建立好的模型导出,用的是PMML,用之前需要导入依赖
建立一个Maven项目,导入依赖:
<!--
  Dependencies for the Java side: Spring Boot starters, the JPMML evaluator,
  JSON helpers, the JAXB runtime and the MySQL JDBC driver.
  NOTE(review): artifacts declared without a <version> presumably rely on a
  Spring Boot parent/BOM for version management; confirm the parent POM.
-->
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <!-- Was <scope>1.8.0</scope>, which is not a valid Maven scope. -->
        <scope>test</scope>
    </dependency>
    <!--
      NOTE(review): pmml-evaluator (1.4.5) and pmml-evaluator-extension (1.4.14)
      are on different 1.4.x releases; confirm this mismatch is intentional.
    -->
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.4.5</version>
    </dependency>
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator-extension</artifactId>
        <version>1.4.14</version>
    </dependency>
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>2.0.1</version>
    </dependency>
    <dependency>
        <groupId>org.glassfish.jaxb</groupId>
        <artifactId>jaxb-runtime</artifactId>
        <version>2.3.0</version>
    </dependency>
    <dependency>
        <groupId>com.jayway.jsonpath</groupId>
        <artifactId>json-path</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>5.1.45</version>
    </dependency>
</dependencies>
从Python里导出PMML文件
# Keep only the columns used in this analysis (features, first-admission flag, target).
data = data[['age', 'edu', 'marriage', 'live', 'drugyear', 'score1',
             "score2", "first", "relapse"]]

# Build the feature matrix and the target vector.
X = data[['age', 'edu', 'marriage', 'live', 'drugyear', 'score1', "score2"]]  # features
X = X.fillna(X.mean())  # replace missing values with the column mean
y = data["relapse"].astype("int")

# Positional hold-out split: the first `train_size` rows train the model, the rest test it.
# The previous slices ([0:527] and [528:]) left row 527 in neither set.
train_size = 527
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# Fit the model inside a PMML pipeline so that it can be exported.
# NOTE(review): PMMLPipeline and sklearn2pmml are not imported in this snippet;
# they presumably come from the sklearn2pmml package and must be imported beforehand.
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
pipeline = PMMLPipeline([('lr', lr)])
pipeline.fit(X=X_train, y=y_train.values.ravel())
# Register a reproducible sample of 15 training rows with the pipeline's verify step.
pipeline.verify(X_train.sample(n=15, random_state=100))

# Export the fitted pipeline to a PMML file.
sklearn2pmml(pipeline, r'logistic.pmml', with_repr=True)
3.为了通过Java调用PMML文件,我们先建立对应的数据库。
先定义一个JDBCUtils工具类,方便连接数据库。
/**
 * Small JDBC helper: loads connection settings from {@code jdbc.properties}
 * on the classpath once, hands out connections and closes JDBC resources.
 */
public class JDBCUtils {
    private static String url;
    private static String user;
    private static String password;
    private static String driver;

    /*
     * Runs once when the class is loaded: reads jdbc.properties and registers
     * the JDBC driver. Failures are printed, matching the original behaviour;
     * the fields then stay null and getConnection() will fail later.
     */
    static {
        try {
            // Locate the properties file through the class loader.
            ClassLoader classLoader = JDBCUtils.class.getClassLoader();
            URL resource = classLoader.getResource("jdbc.properties");
            if (resource == null) {
                // Previously this case surfaced as a NullPointerException
                // inside the static initializer.
                throw new IOException("jdbc.properties not found on the classpath");
            }

            // Load the file; try-with-resources closes the reader, which the
            // original code leaked.
            Properties pro = new Properties();
            try (FileReader reader = new FileReader(resource.getPath())) {
                pro.load(reader);
            }

            url = pro.getProperty("url");
            user = pro.getProperty("user");
            password = pro.getProperty("password");
            driver = pro.getProperty("driver");

            // Register the driver class.
            Class.forName(driver);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    /**
     * Opens a new connection using the loaded settings.
     *
     * @return a new connection object
     * @throws SQLException if the driver manager cannot open the connection
     */
    public static Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url, user, password);
    }

    /**
     * Releases JDBC resources in the order result set, statement, connection.
     * Each argument may be {@code null}; a failure to close one resource is
     * printed and does not prevent closing the others.
     *
     * @param rs   result set to close, or {@code null}
     * @param st   statement to close, or {@code null}
     * @param conn connection to close, or {@code null}
     */
    public static void close(ResultSet rs, Statement st, Connection conn) {
        closeQuietly(rs);
        closeQuietly(st);
        closeQuietly(conn);
    }

    /** Closes one resource if it is non-null, printing any failure. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
根据所建立的数据库编写一个实体类:
/**
 * Value holder for one row of model input: seven numeric attributes that are
 * set through the constructor and rendered by {@link #toString()}.
 */
public class user {
    private double age;
    private double edu;
    private double marri;
    private double live;
    private double drugyear;
    private double score1;
    private double score2;

    /**
     * Creates a record from the seven attribute values, stored as given.
     */
    public user(double age, double edu, double marri, double live, double drugyear, double score1, double score2) {
        this.age = age;
        this.edu = edu;
        this.marri = marri;
        this.live = live;
        this.drugyear = drugyear;
        this.score1 = score1;
        this.score2 = score2;
    }

    /**
     * Renders the record as {@code user{age=..., edu=..., ...}} with the
     * fields in declaration order.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("user{");
        text.append("age=").append(age);
        text.append(", edu=").append(edu);
        text.append(", marri=").append(marri);
        text.append(", live=").append(live);
        text.append(", drugyear=").append(drugyear);
        text.append(", score1=").append(score1);
        text.append(", score2=").append(score2);
        text.append('}');
        return text.toString();
    }
}
接下来就是关键的代码了,调用PMML文件:
首先把你导出的PMML文件复制到Java项目里,然后导入。
//导入我们训练好的模型,logistic.pmml
String pathxml = "logistic.pmml";
通过Map这个数据结构输入到模型里预测。
/**
 * Loads the PMML model stored at {@code pathxml}, evaluates it on one record
 * and prints the value of every target field to standard error.
 *
 * @param map     input record; keys are looked up by the model's input field names
 * @param pathxml path of the PMML file to load
 * @throws Exception if the file cannot be read or the model cannot be loaded or evaluated
 */
public static void predictLrHeart(Map<String, Object> map, String pathxml) throws Exception {
    // The stream is closed automatically when the block exits.
    try (InputStream modelStream = new FileInputStream(new File(pathxml))) {
        PMML model = org.jpmml.model.PMMLUtil.unmarshal(modelStream);
        Evaluator evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(model);

        // Convert each raw value from the record into the form the model expects.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField field : evaluator.getInputFields()) {
            FieldName fieldName = field.getName();
            Object raw = map.get(fieldName.getValue());
            arguments.put(fieldName, field.prepare(raw));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // A model may declare several target fields; print each one.
        for (TargetField target : evaluator.getTargetFields()) {
            FieldName targetName = target.getName();
            System.err.println("预测值输出: " + targetName.getValue()
                    + " 复吸可能性: " + results.get(targetName));
        }
    }
}
因为上面封装的函数的数据结构是Map,Key是字段,value是字段值,我们必须把从数据库里查出来的结果集也放到一个Map里,写一个封装结果集为Map的函数。
/**
 * Runs a query and returns every row as a map from column name to value.
 *
 * <p>Numeric column values are converted to {@code Double}. SQL NULL is stored
 * as {@code null}, and any other value is stored unchanged. The previous code
 * replaced NULL with an empty string and then cast every value to
 * {@code Double}, which threw {@code ClassCastException} for NULLs and for any
 * column that was not already a {@code Double}.
 *
 * @param sql    query text, optionally containing {@code ?} placeholders
 * @param params values bound to the placeholders in order; may be empty
 * @return one map per row, or {@code null} if a {@code SQLException} occurred
 *         (the exception is printed)
 */
public static List<Map<String, Object>> getSimpleObjs(String sql, Double... params) {
    Connection conn = null;
    PreparedStatement ps = null;
    ResultSet rs = null;
    List<Map<String, Object>> mapList = new ArrayList<>();
    try {
        conn = com.example.util.JDBCUtils.getConnection();
        ps = conn.prepareStatement(sql);
        if (params != null) {
            // JDBC parameter indexes start at 1.
            for (int i = 0; i < params.length; i++) {
                ps.setObject(i + 1, params[i]);
            }
        }
        rs = ps.executeQuery();

        // The metadata supplies the column names used as map keys.
        ResultSetMetaData rsmd = rs.getMetaData();
        int columnCount = rsmd.getColumnCount();
        while (rs.next()) {
            Map<String, Object> row = new HashMap<>();
            for (int col = 1; col <= columnCount; col++) {
                Object value = rs.getObject(col);
                if (value instanceof Number) {
                    // Normalise INT/DECIMAL/FLOAT columns to Double.
                    value = ((Number) value).doubleValue();
                }
                row.put(rsmd.getColumnName(col), value);
            }
            mapList.add(row);
        }
        return mapList;
    } catch (SQLException e) {
        e.printStackTrace();
        return null;
    } finally {
        com.example.util.JDBCUtils.close(rs, ps, conn);
    }
}
最后,在主函数里执行查询语句,并调用predictLrHeart函数输出预测结果。
/**
 * Entry point: reads every row of the {@code data} table, prints its columns
 * and runs the exported PMML model on it.
 *
 * @param args unused
 * @throws Exception if loading or evaluating the model fails
 */
public static void main(String[] args) throws Exception {
    // The trained model exported from Python.
    String pathxml = "logistic.pmml";

    // Fetch all rows as column-name -> value maps.
    String sql = "SELECT * FROM data";
    List<Map<String, Object>> userList = getSimpleObjs(sql);
    if (userList == null) {
        // getSimpleObjs returns null when the query fails (it has already
        // printed the exception); without this guard the loop below would
        // throw a NullPointerException.
        return;
    }

    for (Map<String, Object> usermap : userList) {
        // Print the values held in the row.
        for (Map.Entry<String, Object> entry : usermap.entrySet()) {
            System.out.println(entry.getKey() + "," + entry.getValue());
        }
        // Run the prediction for this row.
        predictLrHeart(usermap, pathxml);
    }
}
结果:可以看到,在我们输入的信息里,第一组值被划分到了0类,即不会复吸的人群。输出中的 1=0.133… 表示他会复吸的概率约为13.3%,0=0.866… 表示他不会复吸的概率约为86.6%;由于后者大于0.5,模型就把他归到0类。下面一组的解读方式相同。