用 Python 实现随机森林（RandomForest）
基于 scikit-learn 的机器学习类实现
| 平台 | 发行版 | 依赖库 |
|---|---|---|
| Python 3.7 | Anaconda | scikit-learn 及配套库 |
# -*- coding: utf-8 -*-
import itertools

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix  # builds the confusion matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

try:
    # Standalone joblib; sklearn.externals.joblib was removed in scikit-learn 0.23.
    import joblib  # model persistence
except ImportError:
    from sklearn.externals import joblib  # legacy scikit-learn fallback
class Ctrain_forest:
    """Random-forest helper built on scikit-learn.

    Provides methods to:

    * plot a confusion matrix and save it to a file,
    * train a ``RandomForestClassifier`` while sweeping ``max_depth``,
    * save a fitted model to a given location with joblib,
    * reload a saved model and use it for prediction.
    """

    def plot_confusion_matrix(self, cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues, path="maxtix"):
        """Plot a confusion matrix and save the figure to ``path``.

        Args:
            cm: 2-D confusion matrix; rows are true labels, columns are
                predicted labels (as produced by
                ``sklearn.metrics.confusion_matrix``).
            classes: tick labels (strings), one per class.
            normalize: if True, divide every row by its sum so each cell
                shows a fraction of that row. Rows summing to zero are shown
                as 0 rather than NaN.
            title: figure title.
            cmap: matplotlib colormap, see
                https://matplotlib.org/examples/color/colormaps_reference.html
            path: output file passed to ``plt.savefig`` (default "maxtix").

        The figure is closed after saving so repeated calls do not pile up
        open figures.
        """
        cm = np.asarray(cm)
        if normalize:
            row_sums = cm.sum(axis=1, keepdims=True)
            # Plain division turns empty rows into NaN (and makes the colour
            # threshold NaN too); write 0 there instead.
            cm = np.divide(cm.astype(float), row_sums,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_sums != 0)
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')

        # Chinese glyph support; configure before any text is created.
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(11, 8))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # Integer counts print with 'd'; anything fractional needs '.2f'
        # (format(<float>, 'd') raises ValueError).
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.savefig(path, dpi=500)
        plt.close(fig)

    def train_forest(self, x, y, path):
        """Train random forests over several ``max_depth`` values; keep the best.

        The data are split with ``train_test_split(test_size=0.9,
        random_state=0)``, so only 10% of the samples are used for training
        and 90% are held out. For each candidate ``max_depth`` a
        ``RandomForestClassifier`` is fitted and scored on the held-out
        part. The model with the highest accuracy is returned, and an
        accuracy-vs-max_depth curve is saved to ``path``.

        The held-out part is used both to pick the depth and to compute the
        returned metrics, so those metrics are optimistic.

        Args:
            x: feature matrix.
            y: class labels.
            path: file the accuracy curve is written to (``plt.savefig``).

        Returns:
            tuple ``(clf, matrix, dd, kappa, acc_1)``:
                clf    -- best fitted ``RandomForestClassifier``
                matrix -- confusion matrix on the held-out data
                dd     -- ``classification_report`` text
                kappa  -- Cohen's kappa coefficient
                acc_1  -- overall accuracy
        """
        X_train, X_test, y_train, y_test = train_test_split(
            x, y, test_size=0.9, random_state=0)

        # Candidate depths (same values as before): 11, 51, 91, 131, 171, 211.
        max_depths = np.arange(1, 25, 4) * 10 + 1
        acc_list = []
        best = None  # (accuracy, fitted model, held-out predictions)
        for max_depth in max_depths:
            clf = RandomForestClassifier(
                n_estimators=140 * 2 + 1,
                criterion='gini',
                max_depth=int(max_depth),
                # 'sqrt' is what 'auto' meant for classifiers; 'auto' and
                # min_impurity_split no longer exist in current scikit-learn.
                max_features='sqrt',
                min_samples_split=3,
                min_samples_leaf=3,
                bootstrap=True,
                class_weight="balanced",
                n_jobs=-1)
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)
            acc = accuracy_score(y_test, y_pred)
            acc_list.append(acc)
            print(acc)                                # overall accuracy
            print(cohen_kappa_score(y_test, y_pred))  # Kappa coefficient
            # Strict '>' keeps the shallowest depth when accuracies tie.
            if best is None or acc > best[0]:
                best = (acc, clf, y_pred)

        # Accuracy vs. max_depth curve (the swept parameter is the depth;
        # the number of trees stays fixed).
        mpl.rcParams['font.sans-serif'] = ['SimHei']
        fig = plt.figure(facecolor='w')
        plt.plot(max_depths, acc_list, 'ro-', lw=1)
        plt.xlabel('决策树最大深度', fontsize=15)
        plt.ylabel('预测精度', fontsize=15)
        plt.title('决策树最大深度和过拟合', fontsize=18)
        plt.grid(True)
        plt.savefig(path, dpi=300)
        plt.close(fig)

        acc_1, clf, y_pred_rf = best
        matrix = confusion_matrix(y_test, y_pred_rf)
        kappa = cohen_kappa_score(y_test, y_pred_rf)
        dd = classification_report(y_test, y_pred_rf)
        print(acc_1)  # overall accuracy of the selected model
        print(kappa)  # Kappa coefficient of the selected model
        # Per-feature importances are available as clf.feature_importances_.
        return clf, matrix, dd, kappa, acc_1

    def save_model(self, clf, src):
        """Save a model to ``src`` with ``joblib.dump``.

        Args:
            clf: fitted model.
            src: destination file path.
        """
        joblib.dump(clf, src)

    def get_model_predit(self, data, src):
        """Load the model stored at ``src`` and predict on ``data``.

        Args:
            data: raw input data; wrapped in ``pd.DataFrame`` before
                prediction.
            src: path of the saved model.

        Returns:
            the loaded model's ``predict`` output.

        Security: ``joblib.load`` unpickles the file, which can execute
        arbitrary code; only load model files from trusted sources.
        """
        model = joblib.load(src)
        return model.predict(pd.DataFrame(data))
运行结果（调参图）：

![调参图](https://i-blog.csdnimg.cn/blog_migrate/865e4c7825f662d48923547eeaffa1b9.png)