今日的示例代码包含2个部分
- notebook 文件夹内的 ipynb 文件:介绍今天的思路
- 项目文件夹中的其他部分:拆分后的信贷项目。请学习它是如何拆分的,未来你看到的很多大型项目都采用类似的拆分方法
知识点回顾
- 规范的文件命名
- 规范的文件夹管理
- 机器学习项目的拆分
- 编码格式和类型注解
作业:尝试将之前心脏病项目的 ipynb 按照今天的示例项目整理成规范的形式,并思考哪些部分未来可以复用。
src/data/data_loader.py
import pandas as pd
from sklearn.model_selection import train_test_split
def load_and_split_data(file_path, target_column, test_size=0.2, random_state=42):
    """Read a CSV file and split it into training and test sets.

    Args:
        file_path: Location of the CSV file, handed to ``pandas.read_csv``.
        target_column: Name of the label column. It is removed from the
            feature table and used as the prediction target.
        test_size: Forwarded unchanged to ``train_test_split``.
        random_state: Seed forwarded unchanged to ``train_test_split``.

    Returns:
        A 4-tuple ``(X_train, X_test, y_train, y_test)``.
    """
    frame = pd.read_csv(file_path)
    # Features are every column except the target; labels are the target.
    features = frame.drop(columns=target_column)
    labels = frame[target_column]
    train_x, test_x, train_y, test_y = train_test_split(
        features,
        labels,
        test_size=test_size,
        random_state=random_state,
    )
    return train_x, test_x, train_y, test_y
src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import time
def train_random_forest(X_train, y_train, X_test, y_test, random_state=42):
    """Fit a default random forest and print its test-set evaluation.

    The classifier is created with default hyper-parameters apart from
    ``random_state``. The elapsed wall-clock time covers both fitting and
    predicting on the test set.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: Test features used for prediction.
        y_test: Test labels compared against the predictions.
        random_state: Seed passed to ``RandomForestClassifier``.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    started_at = time.time()
    forest = RandomForestClassifier(random_state=random_state)
    forest.fit(X_train, y_train)
    predictions = forest.predict(X_test)
    elapsed = time.time() - started_at

    print(f"训练与预测耗时: {elapsed:.4f} 秒")

    print("\n默认随机森林 在测试集上的分类报告:")
    report = classification_report(y_test, predictions)
    print(report)

    print("默认随机森林 在测试集上的混淆矩阵:")
    matrix = confusion_matrix(y_test, predictions)
    print(matrix)

    return forest
src/utils/visualization.py
import shap
import matplotlib.pyplot as plt
# Column used for the dependence plot when the caller does not choose one
# and the data actually contains it (kept for backward compatibility).
_LEGACY_DEPENDENCE_FEATURE = 'Years in current job'


def _class_shap_values(shap_values, class_index=0):
    """Return the per-sample, per-feature SHAP matrix for one class.

    ``TreeExplainer.shap_values`` has used different layouts for
    multi-output models: a list with one 2-D array per class, or a single
    3-D array whose last axis is the class (assumed for newer SHAP
    releases -- TODO confirm against the installed version). Anything
    else is returned unchanged.
    """
    if isinstance(shap_values, list):
        return shap_values[class_index]
    if getattr(shap_values, "ndim", 2) == 3:
        return shap_values[:, :, class_index]
    return shap_values


def _resolve_dependence_feature(X_test, dependence_feature):
    """Pick the feature shown in the dependence plot.

    An explicit choice always wins. Otherwise the legacy column is used
    when ``X_test`` has it, and ``"rank(0)"`` (SHAP's selector for the
    top-ranked feature) is used when it does not.
    """
    if dependence_feature is not None:
        return dependence_feature
    columns = getattr(X_test, "columns", None)
    if columns is not None and _LEGACY_DEPENDENCE_FEATURE in columns:
        return _LEGACY_DEPENDENCE_FEATURE
    return "rank(0)"


def plot_shap_values(model, X_test, dependence_feature=None):
    """Draw SHAP bar, violin and dependence plots for a tree model.

    SHAP values are computed with ``shap.TreeExplainer`` and the values
    belonging to class index 0 are plotted, as in the original code.

    Args:
        model: Fitted tree-based model accepted by ``shap.TreeExplainer``.
        X_test: Feature table to explain; must expose ``.shape``.
        dependence_feature: Feature name, index or ``"rank(i)"`` selector
            for the dependence plot. When ``None``, the column
            ``'Years in current job'`` is used if ``X_test`` contains it;
            otherwise the top-ranked feature is used, so the function
            also works on data sets without that column.

    Returns:
        None. The three figures are shown with ``plt.show()``.
    """
    explainer = shap.TreeExplainer(model)
    raw_values = explainer.shap_values(X_test)
    # Normalise list-per-class and 3-D array layouts to one 2-D matrix.
    class_values = _class_shap_values(raw_values, class_index=0)
    print("shap_values[0] shape:", class_values.shape)
    print("X_test shape:", X_test.shape)

    # 1. Bar chart of mean absolute SHAP value per feature.
    print("--- 1. SHAP 特征重要性条形图 ---")
    shap.summary_plot(class_values, X_test, plot_type="bar", show=False)
    plt.title("SHAP Feature Importance (Bar Plot)")
    plt.show()

    # 2. Violin-style summary plot limited to the 10 top features.
    print("--- 2. SHAP 特征重要性蜂巢图 ---")
    shap.summary_plot(class_values, X_test, plot_type="violin", show=False, max_display=10)
    plt.title("SHAP Feature Importance (Violin Plot)")
    plt.show()

    # 3. Dependence plot for the selected feature.
    print("--- 3. SHAP 特征重要性依赖图 ---")
    feature = _resolve_dependence_feature(X_test, dependence_feature)
    shap.dependence_plot(feature, class_values, X_test, show=False)
    plt.title("SHAP Feature Importance (dependence plot)")
    plt.show()
src/main.py
from src.data.data_loader import load_and_split_data
from src.models.random_forest import train_random_forest
from src.utils.visualization import plot_shap_values
if __name__ == "__main__":
    # Step 1: read the raw CSV and create the train/test split.
    X_train, X_test, y_train, y_test = load_and_split_data(
        file_path="data/raw/heart.csv",
        target_column="target",
    )

    # Step 2: fit the random forest and print its evaluation.
    model = train_random_forest(X_train, y_train, X_test, y_test)

    # Step 3: explain the fitted model with SHAP plots.
    plot_shap_values(model, X_test)