# (1) Import the required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import accuracy_score
import graphviz
# (2) Load the data files and clean them
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')
# Report how many values are missing in each column before cleaning
print("Train Data Missing Values:")
print(train_data.isnull().sum())
print("\nTest Data Missing Values:")
print(test_data.isnull().sum())
# Handle missing values.
# Only the columns the model consumes (plus the label) are checked. A bare
# dropna() would also discard every row that is missing a value in a column
# the model never uses, which can throw away most of the training set.
model_columns = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']
train_data = train_data.dropna(subset=model_columns + ['Survived'])
# Every test row needs a prediction, so gaps in the test set are imputed
# instead of dropped. Fill values are taken from the cleaned training set so
# no statistics leak in from the test data: the median for numeric features,
# the most frequent value for 'Sex'.
numeric_columns = ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare']
fill_values = train_data[numeric_columns].median().to_dict()
fill_values['Sex'] = train_data['Sex'].mode().iloc[0]
test_data = test_data.fillna(fill_values)
# (3) Encode the non-numeric categorical feature as integers
le = preprocessing.LabelEncoder()
train_data['Sex'] = le.fit_transform(train_data['Sex'])
# Reuse the encoder fitted on the training data so both sets share one mapping
test_data['Sex'] = le.transform(test_data['Sex'])
# (4) Build the decision tree model, train it and evaluate it
X = train_data[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']]
y = train_data['Survived']
# Hold out 20% of the labelled rows to estimate accuracy on unseen data;
# random_state fixes the split so the reported accuracy is reproducible.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# Fit a decision tree limited to at most three levels of splits
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)
# Evaluate on the held-out split
y_pred = clf.predict(X_val)
accuracy = accuracy_score(y_val, y_pred)
print("Decision Tree Accuracy:", accuracy)
# (5) Predict on the test set
# Select exactly the training feature columns, in the same order, so the test
# matrix cannot drift out of sync with what the model was fitted on.
test_X = test_data[X.columns]
test_data['Survived'] = clf.predict(test_X)
# (6) Visualize the decision tree
# export_graphviz with out_file=None returns the DOT source as a string.
# class_names are listed in the sorted order of clf.classes_
# (assumes the Survived labels are 0 = not survived, 1 = survived; confirm).
dot_data = export_graphviz(
    clf,
    out_file=None,
    feature_names=X.columns,
    class_names=['Not Survived', 'Survived'],
    rounded=True,
    filled=True,
    special_characters=True,
)
# Render the DOT source to a PNG with base name 'titanic_decision_tree'
graph = graphviz.Source(dot_data)
graph.render('titanic_decision_tree', format='png')
# Bar charts of the training-set survival rate broken down by passenger attributes


def _plot_survival_rate(rates, title, xlabel):
    """Show *rates* (a survival rate per group) as a bar chart on a new figure.

    A fresh figure/axes pair is created explicitly. Without an ``ax`` argument,
    pandas' Series.plot draws onto the current axes whenever a figure is still
    open (e.g. under a non-interactive backend where plt.show() does not close
    it), which would stack successive charts on top of each other.
    """
    fig, ax = plt.subplots()
    rates.plot(kind='bar', title=title, ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Survival Rate')
    plt.show()


# Survival rate by passenger class
survival_by_pclass = train_data.groupby('Pclass')['Survived'].mean()
_plot_survival_rate(survival_by_pclass, 'Survival Rate by Pclass', 'Pclass')
# Survival rate by age group.
# pd.cut bins are right-closed: (0, 18], (18, 30], (30, 50], (50, 100]. Ages
# outside (0, 100] become NaN and are left out of the grouping.
age_bins = [0, 18, 30, 50, 100]
age_labels = ['0-18', '18-30', '30-50', '50+']
train_data['AgeGroup'] = pd.cut(train_data['Age'], bins=age_bins, labels=age_labels)
# observed=False keeps every age group as a bar, even an empty one. Passing it
# explicitly also avoids pandas' FutureWarning about this default changing.
survival_by_age = train_data.groupby('AgeGroup', observed=False)['Survived'].mean()
_plot_survival_rate(survival_by_age, 'Survival Rate by Age Group', 'Age Group')
# Survival rate by number of parents/children aboard
survival_by_parch = train_data.groupby('Parch')['Survived'].mean()
_plot_survival_rate(survival_by_parch, 'Survival Rate by Parch', 'Parch')
# Survival rate by number of siblings/spouses aboard
survival_by_sibsp = train_data.groupby('SibSp')['Survived'].mean()
_plot_survival_rate(survival_by_sibsp, 'Survival Rate by SibSp', 'SibSp')