一、准备:
第三方库:sklearn、tensorflow(代码中用 tensorflow 自带的 input_data 模块加载 MNIST 数据集)
二、代码:
# -*- coding: utf-8 -*-
# @Time : 2018/8/21 9:35
# @Author : Barry
# @File : mnist.py
# @Software: PyCharm Community Edition
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Script: fit random forests of increasing size on MNIST and print the test
# accuracy obtained for each forest size.

# Load MNIST through the TensorFlow tutorial helper. one_hot=False requests
# labels in non-one-hot form, which is what the classifier is fed below.
# NOTE(review): input_data comes from the TensorFlow tutorials package --
# confirm it is available in the installed TensorFlow version.
data_dir = 'MNIST_data/'
mnist = input_data.read_data_sets(data_dir, one_hot=False)

# Draw one batch of 50000 training samples; every forest is fitted on this
# same batch so the runs differ only in n_estimators.
batch_size = 50000
batch_x, batch_y = mnist.train.next_batch(batch_size)

# Evaluate on (at most) the first 10000 test samples.
test_x = mnist.test.images[:10000]
test_y = mnist.test.labels[:10000]

print("start random forest")
# Sweep the number of trees over 10, 20, ..., 190. The loop body must be
# indented under the `for`; without it the script fails to parse.
for i in range(10, 200, 10):
    clf_rf = RandomForestClassifier(n_estimators=i)
    clf_rf.fit(batch_x, batch_y)
    y_pred_rf = clf_rf.predict(test_x)
    acc_rf = accuracy_score(test_y, y_pred_rf)
    print("n_estimators = %d, random forest accuracy:%f" % (i, acc_rf))