In this section about ensemble methods, we discuss boosting with a special focus on its most common implementation, AdaBoost.
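The listings in this section use X_train, X_test, y_train, and y_test without showing how they were built. A minimal sketch of one possible setup is given here. It assumes the two-feature Wine data suggested by the 'Alcohol' and 'Hue' labels in the plotting code, and a binary problem, since that code only draws classes 0 and 1. The use of scikit-learn's bundled load_wine, the class that is dropped, the column order, and the split parameters are all assumptions rather than details taken from the original text.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# Assumption: scikit-learn's bundled copy of the Wine data stands in
# for whatever data set the original author loaded.
wine = load_wine()
names = list(wine.feature_names)

# Assumption: two features, Alcohol first and Hue second.
cols = [names.index('alcohol'), names.index('hue')]

# Assumption: a binary problem obtained by dropping one class (target 0)
# and relabelling the remaining two classes as 0 and 1.
mask = wine.target != 0
X = wine.data[mask][:, cols]
y = (wine.target[mask] == 2).astype(int)

# Assumption: a 60/40 stratified split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=1, stratify=y)

print(X_train.shape, X_test.shape)
```

With a different split or feature choice, the accuracies reported in the steps that follow will most likely differ.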
1. Via the base_estimator parameter, we configure the AdaBoostClassifier to use 500 decision tree stumps as weak learners. For comparison, we first train a single stump and evaluate it:
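```python
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# A decision stump: a tree limited to a single split (max_depth=1)
tree = DecisionTreeClassifier(criterion='entropy',
                              max_depth=1,
                              random_state=0)

# Note: scikit-learn 1.2 renamed base_estimator to estimator,
# and base_estimator was removed in 1.4
ada = AdaBoostClassifier(base_estimator=tree,
                         n_estimators=500,
                         learning_rate=0.1,
                         random_state=0)

# Baseline: fit the single stump and score it on both splits
tree = tree.fit(X_train, y_train)
y_train_pred = tree.predict(X_train)
y_test_pred = tree.predict(X_test)
tree_train = accuracy_score(y_train, y_train_pred)
tree_test = accuracy_score(y_test, y_test_pred)
print('Decision tree train/test accuracies %.3f/%.3f'
      % (tree_train, tree_test))
```

```
Decision tree train/test accuracies 0.845/0.854
```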
2. As we can see, the decision tree stump seems to underfit the training data, in contrast with the unpruned decision tree that we saw in the previous section. We now fit the AdaBoost ensemble and evaluate it in the same way:
```python
ada = ada.fit(X_train, y_train)
y_train_pred = ada.predict(X_train)
y_test_pred = ada.predict(X_test)
ada_train = accuracy_score(y_train, y_train_pred)
ada_test = accuracy_score(y_test, y_test_pred)
print('AdaBoost train/test accuracies %.3f/%.3f'
      % (ada_train, ada_test))
```

```
AdaBoost train/test accuracies 1.000/0.875
```
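The gap between the training and test accuracy of the ensemble can also be examined round by round. The sketch below is an illustration rather than part of the original walkthrough. It uses staged_predict, which yields the ensemble's predictions after each boosting round, on an assumed Wine setup (load_wine, Alcohol and Hue, two classes, a 60/40 split). The helper name staged_accuracies and the choice of n_estimators=100 are arbitrary and made only for this example.

```python
from sklearn.datasets import load_wine
from sklearn.ensemble import AdaBoostClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def staged_accuracies(model, X, y):
    """Accuracy of the ensemble after each boosting round (hypothetical helper)."""
    return [accuracy_score(y, y_pred) for y_pred in model.staged_predict(X)]


# Assumed data setup: two Wine features and two of the three classes.
wine = load_wine()
names = list(wine.feature_names)
cols = [names.index('alcohol'), names.index('hue')]
mask = wine.target != 0
X = wine.data[mask][:, cols]
y = (wine.target[mask] == 2).astype(int)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=1, stratify=y)

stump = DecisionTreeClassifier(criterion='entropy', max_depth=1,
                               random_state=0)
# The base learner is passed positionally because its keyword changed
# from base_estimator to estimator in scikit-learn 1.2.
ada = AdaBoostClassifier(stump, n_estimators=100, learning_rate=0.1,
                         random_state=0)
ada.fit(X_train, y_train)

train_curve = staged_accuracies(ada, X_train, y_train)
test_curve = staged_accuracies(ada, X_test, y_test)

# Report a few checkpoints; boosting may stop before n_estimators rounds.
total = len(train_curve)
for n in sorted({1, min(10, total), total}):
    print('rounds=%3d  train=%.3f  test=%.3f'
          % (n, train_curve[n - 1], test_curve[n - 1]))
```

Plotting the two curves against the number of rounds gives a quick view of where additional weak learners stop helping on held-out data.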
3. Finally, we check what the decision regions look like:
```python
import numpy as np
import matplotlib.pyplot as plt

# Build a grid that covers the range of both features
x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

f, axarr = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(8, 3))

# One panel per classifier: predict on the grid, then overlay the training points
for idx, clf, tt in zip([0, 1], [tree, ada], ['Decision Tree', 'AdaBoost']):
    clf.fit(X_train, y_train)
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    axarr[idx].contourf(xx, yy, Z, alpha=0.3)
    axarr[idx].scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1],
                       c='blue', marker='^')
    axarr[idx].scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1],
                       c='red', marker='o')
    axarr[idx].set_title(tt)

# Axis labels as given in the original listing
axarr[0].set_ylabel('Alcohol', fontsize=12)
plt.text(10.2, -1.2, s='Hue', ha='center', va='center', fontsize=12)
plt.show()
```
Reference: Python Machine Learning