# 对手写数字数据集进行随机森林集成分析 (Random-forest ensemble analysis of the MNIST handwritten-digit dataset)
import matplotlib
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
# Load the MNIST handwritten-digit dataset ("mnist_784": 784 pixel features) from OpenML.
from sklearn.datasets import fetch_openml

# Pin version=1: looking the dataset up by name alone makes fetch_openml pick
# the oldest active version and warn that several active versions exist.
mnist = fetch_openml("mnist_784", version=1)

# Fit a 500-tree random forest on all feature columns; n_jobs=-1 builds the
# trees in parallel on every available CPU core. A fixed random_state makes
# the learned feature importances (and hence the heat map below) reproducible
# from run to run.
rf_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=42)
rf_clf.fit(mnist["data"], mnist["target"])
# Heat-map plotting helper.
def plot_digit(data, shape=(28, 28)):
    """Draw a flat pixel vector as a heat map on the current pyplot axes.

    Args:
        data: Array-like with ``shape[0] * shape[1]`` values, e.g. one MNIST
            sample or ``RandomForestClassifier.feature_importances_``.
            NumPy arrays, plain lists and pandas Series are all accepted.
        shape: ``(rows, cols)`` of the rendered image; defaults to MNIST's
            28x28 layout.

    Returns:
        The ``AxesImage`` produced by ``plt.imshow``. Callers may pass it to
        ``plt.colorbar`` explicitly or ignore it.
    """
    import numpy as np  # local import: the file has no top-level NumPy import

    # np.asarray first: a pandas Series (one row of a DataFrame) has no
    # .reshape method, so calling data.reshape directly would fail on it.
    image = np.asarray(data).reshape(shape)
    # Pass the colormap by its registered name, the form documented for
    # imshow's cmap, rather than the matplotlib.cm.hot module attribute.
    mappable = plt.imshow(image, cmap="hot")
    plt.axis("off")  # hide ticks/frame; only the pixel intensities matter
    return mappable
# Visualise which pixels the forest relied on most: render the per-pixel
# importances as an image, then label the colour bar's two extremes with
# words instead of raw importance values.
importances = rf_clf.feature_importances_
lowest, highest = importances.min(), importances.max()
plot_digit(importances)
char = plt.colorbar(ticks=[lowest, highest])  # colour-bar handle
char.ax.set_yticklabels(["Not important", "Very important"])
plt.show()
# 最终得到的结果 (Final result):