五、模型构建
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
import joblib
1. 读取数据
# 读取数据
Xtrain = pd.read_excel('../tmp/sj_final.xlsx')
ytrain = pd.read_excel('../data/water_heater_log.xlsx')
test = pd.read_excel('../data/test_data.xlsx')
2. 训练集测试集区分
# 训练集测试集区分。
x_train, x_test, y_train, y_test = Xtrain.iloc[:,5:],test.iloc[:,4:-1],\
ytrain.iloc[:,-1],test.iloc[:,-1]
# 标准化
stdScaler = StandardScaler().fit(x_train)
x_stdtrain = stdScaler.transform(x_train)
x_stdtest = stdScaler.transform(x_test)
3. 建立模型
# 建立模型
bpnn = MLPClassifier(hidden_layer_sizes=(15, 12),
max_iter=200,
solver='lbfgs',
random_state=50)
bpnn.fit(x_stdtrain, y_train)
# 保存模型
joblib.dump(bpnn, '../tmp/water_heater_nnet.m')
print('构建的模型为:\n', bpnn)
构建的模型为:
MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
beta_2=0.999, early_stopping=False, epsilon=1e-08,
hidden_layer_sizes=(15, 12), learning_rate='constant',
learning_rate_init=0.001, max_fun=15000, max_iter=200,
momentum=0.9, n_iter_no_change=10, nesterovs_momentum=True,
power_t=0.5, random_state=50, shuffle=True, solver='lbfgs',
tol=0.0001, validation_fraction=0.1, verbose=False,
warm_start=False)
六、模型检验
1. 检验
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve
import joblib
import matplotlib.pyplot as plt
bpnn = joblib.load('../tmp/water_heater_nnet.m') # 加载模型
y_pred = bpnn.predict(x_stdtest) # 返回预测结果
print('神经网络预测结果评价报告:\n', classification_report(y_test, y_pred))
神经网络预测结果评价报告:
precision recall f1-score support
0 0.45 0.42 0.43 12
1 0.82 0.84 0.83 37
accuracy 0.73 49
macro avg 0.64 0.63 0.63 49
weighted avg 0.73 0.73 0.73 49
2. 可视化
# 绘制roc曲线图
plt.rcParams['font.sans-serif'] = 'SimHei' # 显示中文
plt.rcParams['axes.unicode_minus'] = False # 显示负号
fpr, tpr, thresholds = roc_curve(y_pred,y_test) # 求出TPR和FPR
plt.figure(figsize=(6,4)) # 创建画布
plt.plot(fpr,tpr) # 绘制曲线
plt.title('用户用水事件识别ROC曲线') # 标题
plt.xlabel('FPR') # x轴标签
plt.ylabel('TPR') # y轴标签
plt.savefig('../tmp/用户用水事件识别ROC曲线.png') # 保存图片
plt.show() # 显示图形
数据及代码:链接:https://pan.baidu.com/s/1u4k2oYId4Bc9E5YupbpybA
提取码:3rwm