Algorithm Overview
ExtraTrees (Extremely Randomized Trees), like Random Forest, is an ensemble of decision trees. The difference lies in how each split is chosen:
- Random Forest searches for the best split threshold on each candidate feature, scored by Gini impurity or information entropy
- ExtraTrees draws a random threshold for each candidate feature and keeps the best of these random splits

Random Forest can also be used for feature selection, but when overfitting is a concern the extra randomization of ExtraTrees often works better: if a feature still shows strong discriminative power under random splits (which mimics an unknown test distribution), that is evidence the feature really matters.
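A minimal sketch of the contrast above, fitting both ensembles on the same synthetic data (the dataset parameters mirror the example code below; the random seeds are illustrative assumptions):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.model_selection import cross_val_score

x, y = make_classification(n_samples=1000, n_classes=4, n_features=10,
                           n_informative=8, random_state=0)

# Random Forest: searches for the best threshold (by impurity) at every split
rf_score = cross_val_score(RandomForestClassifier(random_state=0), x, y, cv=5).mean()
# ExtraTrees: draws one random threshold per candidate feature, keeps the best
et_score = cross_val_score(ExtraTreesClassifier(random_state=0), x, y, cv=5).mean()

print(f"RandomForest CV accuracy: {rf_score:.3f}")
print(f"ExtraTrees   CV accuracy: {et_score:.3f}")
```

The scores are usually close; the point is that ExtraTrees reaches comparable accuracy while never optimizing individual thresholds, which is also why its feature importances are less prone to overfitting quirks of the training set.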
Example Code
import pandas as pd

def load_data():
    """Generate training data."""
    from sklearn.datasets import make_classification
    data_x, data_y = make_classification(n_samples=1000, n_classes=4, n_features=10, n_informative=8)
    df_x = pd.DataFrame(data_x, columns=['f_1', 'f_2', 'f_3', 'f_4', 'f_5', 'f_6', 'f_7', 'f_8', 'f_9', 'f_10'])
    df_y = pd.Series(data_y)
    return df_x, df_y

def select_from_model(x_data, y_data):
    from sklearn.feature_selection import SelectFromModel
    from sklearn.ensemble import ExtraTreesClassifier
    # Use ExtraTrees as the basis for feature selection
    sf_model = SelectFromModel(ExtraTreesClassifier())
    sf_model.fit(x_data, y_data)
    print("Suggested features to keep:", x_data.columns[sf_model.get_support()])
    print("Feature importances:", sf_model.estimator_.feature_importances_)
    # sf_model.threshold_    # the importance cutoff actually used
    # sf_model.get_support() # boolean mask telling which columns were selected
    return sf_model.transform(x_data)  # the selected feature matrix

if __name__ == '__main__':
    value_x, value_y = load_data()
    select_from_model(value_x, value_y)  # select features from x_data given y_data
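By default, `SelectFromModel` keeps features whose importance is at least the mean importance. A short sketch (assuming the same synthetic data as above) of the `threshold` parameter, which controls how many features survive:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.feature_selection import SelectFromModel

x, y = make_classification(n_samples=1000, n_classes=4, n_features=10,
                           n_informative=8, random_state=0)

# threshold="median" keeps features whose importance is at least the median,
# i.e. roughly half of them; the default threshold is the mean importance.
sf = SelectFromModel(ExtraTreesClassifier(random_state=0), threshold="median")
x_new = sf.fit_transform(x, y)
print(x_new.shape)  # half of the 10 features remain: (1000, 5)
```

Other accepted values include a float cutoff or strings like "1.5*mean"; see the `SelectFromModel` documentation linked below.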
References
- RandomForestClassifier vs ExtraTreesClassifier in scikit-learn: https://stackoverflow.com/questions/22409855/randomforestclassifier-vs-extratreesclassifier-in-scikit-learn
- Gini impurity, details and examples: https://bambielli.com/til/2017-10-29-gini-impurity/
- SelectFromModel documentation: https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFromModel.html
- ExtraTreesClassifier documentation: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html