# Learning notes on random forests in the sklearn library.
# Decision-tree classifier that classifies the ten digits 0-9.
import scipy.io as scio
from sklearn import tree
import graphviz
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Load the MNIST training arrays from MATLAB files, keeping only the first
# 2500 rows of the features and of the labels.
data_features = scio.loadmat('mnist_train.mat')['mnist_train'][:2500]
# Flatten the label array to 1-D in the same step as loading it.
data_labels = scio.loadmat('mnist_train_labels.mat')['mnist_train_labels'][:2500].ravel()

# Hold out 20% of the rows for testing; the fixed seed keeps the split
# the same from run to run.
train_features, test_features, train_labels, test_labels = train_test_split(
    data_features,
    data_labels,
    test_size=0.2,
    random_state=42,
)

# Report the shapes of the training and test feature arrays, one per line.
print(train_features.shape, test_features.shape, sep='\n')
# Fit a single decision tree with default settings on the training split
# (fit returns the estimator itself, so the calls can be chained).
clf = tree.DecisionTreeClassifier().fit(train_features, train_labels)

# Export the fitted tree as DOT text (out_file=None returns it as a string
# instead of writing a file), then render it with Graphviz under the base
# name "digit_tree".
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("digit_tree")
def _forest_accuracies(criterion, train_x, train_y, test_x, test_y, tree_counts):
    """Return the test-set accuracy of a random forest for each forest size.

    For every value in ``tree_counts`` a fresh ``RandomForestClassifier``
    with that many trees and the given split ``criterion`` is fitted on the
    training data. Its accuracy is the fraction of test rows whose predicted
    label equals the true label.

    Args:
        criterion: Split criterion passed to ``RandomForestClassifier``.
        train_x: Training features.
        train_y: Training labels.
        test_x: Test features.
        test_y: Test labels, in the same order as the rows of ``test_x``.
        tree_counts: Iterable of ``n_estimators`` values to evaluate.

    Returns:
        A list of floats, one accuracy per entry of ``tree_counts``.
    """
    accuracies = []
    for n_trees in tree_counts:
        # NOTE: no random_state is passed, so accuracies vary between runs.
        forest = RandomForestClassifier(n_estimators=n_trees, criterion=criterion)
        forest.fit(train_x, train_y)
        predicted = forest.predict(test_x)
        # Count matches with a plain Python int so the ratio is a Python float.
        hits = sum(1 for guess, truth in zip(predicted, test_y) if guess == truth)
        accuracies.append(hits / len(predicted))
    return accuracies


# Forest sizes 5, 10, ..., 150 -- the same values as i*5 for i in 1..30.
acc = _forest_accuracies(
    "entropy",
    train_features, train_labels,
    test_features, test_labels,
    range(5, 151, 5),
)
print(acc)

acc1 = _forest_accuracies(
    "gini",
    train_features, train_labels,
    test_features, test_labels,
    range(5, 151, 5),
)
print(acc1)
# x-axis values: forest sizes 5, 10, ..., 150, matching the order in which
# the accuracy lists were filled.
tree_counts = list(range(5, 151, 5))

# Draw both accuracy curves on one figure so the criteria can be compared.
plt.figure()
plt.plot(tree_counts, acc, label='entropy')
# 'm.-.' = magenta colour, point markers, dash-dot line.
plt.plot(tree_counts, acc1, 'm.-.', label='gini')
plt.legend()
plt.show()
# The dataset used here comes from: https://github.com/zhouweixin/bayes-mnist