【随机森林学习记录-手写数字识别分类】

sklearn库中的随机森林学习记录

决策树分类器,分类0-9这十个数字

import scipy.io as scio
from sklearn import tree
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

data_features = scio.loadmat('mnist_train.mat')['mnist_train'][:2500]
data_labels = scio.loadmat('mnist_train_labels.mat')['mnist_train_labels'][:2500]
data_labels = data_labels.ravel()
train_features, test_features, train_labels, test_labels = train_test_split(data_features, data_labels, test_size=0.2, random_state=42)

print(train_features.shape)
print(test_features.shape)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(train_features, train_labels)
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("digit_tree")

acc = []
acc1 = []
for i in range(1,31):
    rf = RandomForestClassifier(n_estimators=i*5, criterion="entropy")
    rf.fit(train_features, train_labels)
    predictions = rf.predict(test_features)
    n = 0
    for k in range(len(predictions)):
        if predictions[k] == test_labels[k]:
            n += 1
    acc.append(n/len(predictions))
print(acc)
for i in range(1,31):
    rf1 = RandomForestClassifier(n_estimators=i*5, criterion="gini")
    rf1.fit(train_features, train_labels)
    predictions1 = rf1.predict(test_features)
    n = 0
    for k in range(len(predictions1)):
        if predictions1[k] == test_labels[k]:
            n += 1
    acc1.append(n/len(predictions1))
print(acc1)

plt.figure()
plt.plot([i*5 for i in range(1, 31)], acc, label='entropy')
plt.plot([i*5 for i in range(1, 31)], acc1, 'm.-.', label='gini')
plt.legend()
plt.show()

在这里插入图片描述

其中的数据集来自:https://github.com/zhouweixin/bayes-mnist

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值