随机森林对标准手写数据集分类

随机森林对标准手写数据集分类

介绍

随机森林,指的是利用多棵树对样本进行训练并预测的一种分类器,将许多棵决策树整合成森林,并合起来用来预测最终结果,可以用来做分类、回归等问题。大多数情况下效果远要比SVM,log回归,KNN等算法效果好。

随机森林的构建过程

1.从原始训练集中随机有放回采样选出m个样本,共进行N次采样,生成N个训练集
2.对于N个训练集,我们分别训练N个决策树模型
3.对于单个决策树模型,假设训练样本特征的个数为n,那么每次分裂时根据信息增益/信息增益比/基尼指数选择最好的特征进行分裂
4.每棵树都一直这样分裂下去,直到该节点的所有训练样例都属于同一类。在决策树的分裂过程中不需要剪枝
5.将生成的多棵决策树组成随机森林。对于分类问题,按多棵树分类器投票决定最终分类结果,对于回归问题,由多棵树预测值的均值决定最终预测结果

优点、缺点

具有极高的准确率
随机性的引入,使得随机森林不容易过拟合
随机性的引入,使得随机森林有很好的抗噪声能力
能处理很高维度的数据,并且不用做特征选择
既能处理离散型数据,也能处理连续型数据,数据集无需规范化
训练速度快,可以得到变量重要性排序
容易实现并行化
当随机森林中的决策树个数很多时,训练时需要的空间和时间会较大
随机森林模型还有许多不好解释的地方,有点算个黑盒模型

from sklearn.ensemble import RandomForestClassifier
data=[[0,0,0],[1,1,1],[2,2,2],[1,1,1],[2,2,2],[3,3,3],[1,1,1],[4,4,4]]
target=[0,1,2,1,2,3,1,4]
rf = RandomForestClassifier()
rf.fit(data,target)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)
print(rf.predict_proba([[1,1,1]]))
#[[ 0.  1.  0.  0.  0.]]
#输出是5个数,因为target有5个值
import matplotlib.pyplot as plt
import matplotlib.pyplot
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
digits = load_digits()
fig = plt.figure(figsize=(6,6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(64):
    ax = fig.add_subplot(8,8,i+1, xticks=[], yticks=[])
    ax.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')
    ax.text(0,7,str(digits.target[i]))
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, random_state=0)
model = RandomForestClassifier(n_estimators=1000)
model.fit(x_train, y_train)
ypre = model.predict(x_test)
print(metrics.classification_report(ypre, y_test))
mat = confusion_matrix(y_test, ypre)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False)
plt.xlabel('true label')
plt.ylabel('predicted label')
matplotlib.pyplot.show()

结果截图
手写数据集的部分样本和热力图

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值