什么是随机森林?
随机森林(Random Forest)是一种集成学习(Ensemble Learning)算法,它将多个决策树(Decision Tree)组合在一起形成一个强大的分类器。随机森林通过随机选择特征、随机选择样本、随机生成决策树来构建模型,以此来降低过拟合风险,并提高模型的泛化能力。
随机森林可以用于分类和回归任务,它的优点是能够有效地处理高维数据和大规模数据集,同时能够避免过拟合问题。在分类任务中,随机森林可以输出每个样本属于每个类别的概率;在回归任务中,随机森林可以输出一个连续的数值作为预测结果。
做分类任务代码示例:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
# 加载Iris数据集
iris = load_iris()
# 将数据集分为特征和目标变量
X = iris.data
y = iris.target
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建随机森林分类器
rf = RandomForestClassifier(n_estimators=100, random_state=42)#100颗决策树
# 训练模型
rf.fit(X_train, y_train)
# 预测测试集
y_pred = rf.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)
RandomForestClassifier
的参数选择会影响模型性能,但并没有一个单一的参数选择策略可以适用于所有场景。
n_estimators
:这个参数指定了森林中决策树的数量。通常,较大的值可以提供更好的性能,但也会增加计算时间和内存使用。通常,建议从 50 到 100 开始,然后根据模型性能进行调整。max_depth
:这个参数限制了每个决策树的最大深度。较大的值可以减少过拟合,但也会影响模型的预测能力。建议从 10 到 20 开始,然后根据模型性能进行调整。min_samples_split
:这个参数指定了在进行分裂时每个节点必须考虑的最小样本数。较小的值可以增加模型的灵活性和减少过拟合的风险,但也会降低模型的准确率。建议从 2 到 5 开始,然后根据模型性能进行调整。min_samples_leaf
:这个参数指定了叶子节点至少应包含的最小样本数。较小的值可以增加模型的灵活性和减少过拟合的风险,但也会降低模型的准确率。建议从 1 到 5 开始,然后根据模型性能进行调整。random_state
:这个参数用于设置随机种子,以便结果可重复。这对于研究和比较不同超参数组合时的性能非常有用。如果没有特殊需求,可以将其设置为 42 或 None。
做回归任务代码示例:
# 导入所需的库
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
# 加载数据集
boston = load_boston()
# 将数据集分为特征和目标变量
X = boston.data
y = boston.target
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建随机森林回归模型
rf = RandomForestRegressor(n_estimators=100, random_state=42)
# 训练模型
rf.fit(X_train, y_train)
# 预测测试集结果并计算均方根误差(RMSE)
y_pred = rf.predict(X_test)
rmse = mean_squared_error(y_test, y_pred, squared=False)
print("随机森林回归模型的均方根误差为:", rmse)