决策树二分类之泰坦尼号克生存预测
一、项目简介
官方链接:Titanic - Machine Learning from Disaster
1.1 项目背景
- 1、泰坦尼克号: 英国白星航运公司下辖的一艘奥林匹克级邮轮,于1909年3月31日在爱尔兰贝尔法斯特港的哈兰德与沃尔夫造船厂动工建造,1911年5月31日下水,1912年4月2日完工试航。
- 2、首航时间: 1912年4月10日
- 3、航线: 从英国南安普敦出发,途经法国瑟堡-奥克特维尔以及爱尔兰昆士敦,驶向美国纽约。
- 4、沉船: 1912年4月15日(1912年4月14日23时40分左右撞击冰山)
船员+乘客人数:2224 - 5、遇难人数: 1502(67.5%)
1.2 目标问题
- 根据训练集中各位乘客的特征及是否获救标志的对应关系训练模型,预测测试集中的乘客是否获救。(
二元分类问题
)
1.3 字段描述
二、训练集(train)建模
- 数据集链接:train.csv
2.1 导入相关库
import numpy as np
import pandas as pd
from scipy import stats
# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
# 可视化相关库
import seaborn as sns
import matplotlib.pyplot as plt
# 解决mac 系统画图中文不显示问题
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
# # 解决win 系统中文不显示问题
# from pylab import mpl
# mpl.rcParams['font.sans-serif']=['SimHei']
# 不显示警告
import warnings
warnings.filterwarnings('ignore')
2.2 自定义函数
def PieChart(df):
'''
绘制环形饼图
'''
plt.figure(
figsize = (4,4), # 设置图片大小
dpi = 100 # 精度
)
df.value_counts().plot(
kind = 'pie', # 设置绘图类型为饼图
wedgeprops = {'width':0.4}, # 设置空心比例
autopct = "%.1f%%" # 显示百分比
)
def BarPlot(df,ColumnsName):
'''
绘制不同 ColumnsName 的存活人数柱形图
'''
ColumnsDf = df.groupby(['Survived',ColumnsName]).count()[['PassengerId']].reset_index()\
.rename(columns={"PassengerId":"Count"})
plt.figure(figsize=(4,3),dpi=150)
sns.barplot(
data=ColumnsDf,
x=ColumnsName,
y="Count",
hue="Survived"
)
plt.title('Survived Count Of {}'.format(ColumnsName))
def OneHot(x):
'''
功能:one-hot 编码
传入:需要编码的分类变量
返回:返回编码后的结果,形式为 dataframe
'''
# 通过 LabelEncoder 将分类变量打上数值标签
lb = LabelEncoder() # 初始化
x_pre = lb.fit_transform(x) # 模型拟合
x_dict = dict([[i,j] for i,j in zip(x,x_pre)]) # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}
x_num = [[x_dict[i]] for i in x] # 通过 x_dict 将分类变量转为数值型
# 进行one-hot编码
enc = OneHotEncoder() # 初始化
enc.fit(x_num) # 模型拟合
array_data = enc.transform(x_num).toarray() # one-hot 编码后的结果,二维数组形式
# 转成 dataframe 形式
df = pd.DataFrame(array_data)
inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值
# columns 重命名
if type(x) == pd.Series:
firs_name = x.name
else:
firs_name = ""
df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]
return df
2.3 特征工程
2.3.1 数据导入
train = pd.read_csv("train.csv")
train.head(5)
2.3.2 数据初探
(1)特征信息
train.info()
- 可以看出训练集共有891个样本,且有三个字段(Age、Cabin、Embarked)存在缺失值。
(2)特征缺失值比例统计
train.isnull().sum()/len(train)
- 可以看出,字段Cabin缺失比例较大,达到77%。
(3)数值特征描述统计
train.describe()
- 可以看出,票价(Fare)最低为0,估计是船上的员工。
2.3.3 单特征可视化分析与处理
(1)Survived 是否存活
########################## 1、Survived 是否存活 ##########################
# Y标签,{0:不存活,1:存活}
# 有无缺失值:无
# 数据处理:不处理
# 从图中可以看出,死亡人数与存活人数占比差异不大
PieChart(train['Survived'])
(2)Pclass 乘客等级
########################## 2、Pclass 乘客等级 ##########################
# 无缺失值,等级变量
# 用柱状图查看各乘客等级的存活情况
# 可以看出 Pclass=3 的乘客中,存活人数远低于死亡人数
BarPlot(train,"Pclass")
# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
train['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in train['Pclass']]
# 查看不同 PclassType 的存活情况
BarPlot(train,"PclassType")
# 再对 PclassType 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['PclassType']),
left_index=True,
right_index=True
)
(3)Name 乘客姓名
########################## 3、Name 乘客姓名 ##########################
# 字符串变量
# 有无缺失值:无
# 从乘客姓名中获取头街
# 姓名中头街字符串与定义头街类别之间的关系
# Officer: 政府官员,
# RoyaIty: 王室(皇室),
# Mr: 已婚男士,
# Mrs: 已婚女士,
# Miss: 年轻未婚女子,
# Master: 有技能的人/教师
# 新建字段 Title_Dict
Title_Dict = {
'Mr':'Mr',
'Mrs':'Mrs',
'Miss':'Miss',
'Master': 'Master',
'Don':'Royalty',
'Rev':'Officer',
'Dr':')fficer',
'Mme':'Mrs',
'Ms':'Mrs',
'Major':'Officer',
'Lady': 'Royalty',
'Sir': 'Royalty',
'Mlle':'Miss',
'Col': 'Officer',
'Capt':'Officer',
'the Countess': 'Royalty',
'Jonkheer': 'Royalty',
'Dona': 'Royalty'
}
train['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in train['Name']] # 对Name进行分类
# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为Mr(已婚男士)中,死亡人数远远大于存活人数;
# 乘客为Mrs(已婚女士)、Miss(年轻未婚女子)中,死亡人数远远低于存活人数;
BarPlot(train,"NameType")
# 数据进一步处理:将 NameType 分成三类
# Mr(已婚男士)
# Mrs(已婚女士)、Miss(年轻未婚女子)
# 其他
train['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \
for i in train['NameType']]
# 查看不同 NameType2 的存活情况
BarPlot(train,"NameType2")
# 再对 NameType2 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['NameType2']),
left_index=True,
right_index=True
)
(4)Sex 性别
########################## 4、Sex 性别 ##########################
# 分类变量
# 有无缺失值:无
# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为男性中,死亡人数远远大于存活人数
BarPlot(train,"Sex")
# 对 Sex 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['Sex']),left_index=True,right_index=True)
(5)Age 年龄
########################## 5、Age 年龄 ##########################
# 连续变量
# 有无缺失值:有,缺失比例19.9%
# 缺失值用均值填充
train['Age'] = train['Age'].fillna(train['Age'].mean())
# 用直方图查看各 Age 的存活情况
# 可以看出 可以看出小于5岁的小孩存活率很高
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(train[train['Survived']==0]['Age'],color="red",kde=False)
sns.distplot(train[train['Survived']==1]['Age'],color="blue",kde=False)
# 数据处理:将 Age 分成两类,Age<=5、Age>5
train['AgeType'] = ["Age<=5" if i <= 5 else "Age>5" for i in train['Age']]
# 查看不同 AgeType 的存活情况
BarPlot(train,"AgeType")
# 再对 AgeType 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['AgeType']),
left_index=True,
right_index=True
)
(6)SibSp 堂兄弟妹个数
########################## 6、SibSp 堂兄弟妹个数 ##########################
# 无缺失值,等级变量
# 用柱状图查看各堂兄弟妹个数的存活情况
# 可以看出 SibSp=0 的乘客中,死亡人数较多
BarPlot(train,"SibSp")
# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
train['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in train['SibSp']]
# 查看不同 SibSpType 的存活情况
BarPlot(train,"SibSpType")
# 再对 SibSpType 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['SibSpType']),
left_index=True,
right_index=True
)
(7)Parch 父母与小孩的个数
########################## 7、Parch 父母与小孩的个数 ##########################
# 连续变量
# 有无缺失值:无
# 用柱状图查看父母与小孩的个数的存活情况
# 可以看出 Parch=0 的乘客中,死亡人数较多
BarPlot(train,"Parch")
# 数据处理:将 Parch 分成两类,Parch=0、Parch>0
train['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in train['Parch']]
# 查看不同 ParchType 的存活情况
BarPlot(train,"ParchType")
# 再对 ParchType 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['ParchType']),
left_index=True,
right_index=True
)
(8)Ticket 船票信息
- 字符变量
- 有无缺失值:无
- 数据处理:这里直接删去(下文会删)
(9)Fare 票价
########################## 9、Fare 票价 ##########################
# 连续变量
# 有无缺失值:无
# 查看Fare(票价)= 0 的生存情况
Fare0 = train[train['Fare']==0]
Fare0Survived = Fare0.groupby(['Survived']).count()[['PassengerId']].reset_index().rename(columns={"PassengerId":"Count"})
plt.figure(figsize=(4,3),dpi=150)
sns.barplot(
data=Fare0Survived,
x="Survived",
y="Count"
)
plt.title('Survived Count Of Fare=0')
# 查看Fare(票价)!= 0 的生存情况
Fare1 = train[train['Fare']!=0]
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(Fare1[Fare1['Survived']==0]['Fare'],color="red",kde=False)
sns.distplot(Fare1[Fare1['Survived']==1]['Fare'],color="blue",kde=False)
plt.title('Survived Count Of Fare!=0')
# 对 Fare 分成三类
# Fare = 0
# Fare <=50
# Fare > 50
train['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in train['Fare']]
# 用柱状图查看不同 FareType 的存活情况
# 可以看出 Fare=0 的乘客中,乘客几乎都死亡
# Fare <=50 的乘客中,死亡人数大于存活人数
# Fare > 50 的乘客中,存活人数大于死亡人数
BarPlot(train,"FareType")
# 再对 FareType 进行One-Hot编码处理
train = pd.merge(
train,
OneHot(train['FareType']),
left_index=True,
right_index=True
)
(10)Cabin 船舱
- 离散变量
- 有无缺失值:有,缺失值比例高达77%
- 数据处理:缺失值比例较大,直接删去(下文会删)
(11)Embarked 登船的港口
########################## 11、Embarked 登船的港口 ##########################
# 离散变量
# 有无缺失值:有,缺失值比例很低
# 用柱状图查看各登船的港口的存活情况
# 可以看出 Embarked=S 的乘客中,死亡人数较多
BarPlot(train,"Embarked")
# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
mode = stats.mode(train['Embarked'])[0][0] # 众数
train['Embarked'] = train['Embarked'].fillna(mode)
train = pd.merge(train,OneHot(train['Embarked']),left_index=True,right_index=True)
2.3.4 衍生特征可视化分析与处理
FamilyNumbers 家庭人数
########################## FamilyNumbers 家庭人数 ##########################
# 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
train['FamilyNumbers'] = train['SibSp'] + train['Parch'] + 1
# 用柱状图查看各家庭人数的存活情况
# 可以看出 家庭人数=1 的乘客中,死亡人数较多
# 家庭人数>=5 的乘客中,存活人数较多
BarPlot(train,"FamilyNumbers")
# 新增 FamilyType 字段
# 1 : 单身(Single)
# 2-4:小家庭(Family_Small)
# >4: 大家庭(Family_Large)
train['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in train['FamilyNumbers']]
# 查看不同 FamilyType 的存活情况
BarPlot(train,"FamilyType")
# 对 FamilyType 进行One-hot编码处理
train = pd.merge(train,OneHot(train['FamilyType']),left_index=True,right_index=True)
2.3.5 删除冗余字段
drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\
'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\
'FamilyNumbers','FamilyType']
train.drop(drop_columns,axis=1,inplace=True)
2.3.6 相关性矩阵可视化
- 采用斯皮尔曼相关系数
corr_df = train.corr(method="spearman")[['Survived']].sort_values(by="Survived",ascending=False)
plt.figure(figsize=(1,8),dpi=100)
sns.heatmap(
corr_df,
cmap='Blues',
center=0,
vmax=1,
vmin=-1,
annot=True,
annot_kws={'size':10,'weight':'bold', 'color':'red'}
)
2.4 决策树模型训练
2.4.1 数据标准化(Z-score)
def ZscoreNormalization(x):
'''
Z-score 标准化
'''
return (x - np.mean(x)) / np.std(x)
data = train.drop("Survived",axis=1).agg(ZscoreNormalization)
data['Lable'] = train['Survived']
2.4.2 划分训练集、测试集
- 按7:3比例划分
x_train, x_test, y_train, y_test = train_test_split(
data.drop("Lable",axis=1),
data['Lable'],
test_size = 0.3,
random_state = 0
)
2.4.3 网格寻参与交叉验证
param_grid = {
'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”
'splitter' : ['best','random'], # 划分方式:{“best”, “random”}, default=”best”
'max_depth' : range(1,6), # 最大深度
'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数
'min_samples_leaf' : range(1,6), # 叶节点所需的最小样本数
}
clf = DecisionTreeClassifier() # 初始化
gs = GridSearchCV(clf,param_grid,cv=5) # 网格搜索与交叉验证
gs.fit(x_train,y_train) # 模型训练
print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
print("Best Score: ",gs.best_score_) # 打印最好分数
注意: 每次运行的结果输出会存在差别。
2.4.4 模型评价
print("\n---------- 模型评价 ----------")
y_pred = gs.predict(x_test) # 预测
cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
df_cm = pd.DataFrame(cm) # 构建DataFrame
print('Accuracy score:', accuracy_score(y_test, y_pred)) # 准确率
print('Recall:', recall_score(y_test, y_pred, average='weighted')) # 召回率
print('F1-score:', f1_score(y_test, y_pred, average='weighted')) # F1分数
print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度
2.4.5 混淆矩阵可视化
plt.figure(dpi=150)
heatmap = sns.heatmap(df_cm, annot=True, fmt='.0f', cmap='Blues')
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=0, ha='right')
plt.title('DecisionTreeClassifier Model Results')
plt.show()
2.4.6 ROC曲线
y_pred_proba = gs.predict_proba(np.array(x_test))[:,1]
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
sns.set()
plt.figure(figsize=(5,4),dpi=150)
plt.plot(fpr, tpr)
plt.plot(fpr, fpr, linestyle = '-' , color = 'k')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
AU = np.round(roc_auc_score(y_test, y_pred_proba), 2)
plt.title(f'AU: {AU}');
plt.show()
三、完整代码(含对test预测)
- 含预测,不含可视化
import numpy as np
import pandas as pd
from scipy import stats
# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
# 不显示红色警告
import warnings
warnings.filterwarnings('ignore')
def OneHot(x):
'''
功能:one-hot 编码
传入:需要编码的分类变量
返回:返回编码后的结果,形式为 dataframe
'''
# 通过 LabelEncoder 将分类变量打上数值标签
lb = LabelEncoder() # 初始化
x_pre = lb.fit_transform(x) # 模型拟合
x_dict = dict([[i,j] for i,j in zip(x,x_pre)]) # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}
x_num = [[x_dict[i]] for i in x] # 通过 x_dict 将分类变量转为数值型
# 进行one-hot编码
enc = OneHotEncoder() # 初始化
enc.fit(x_num) # 模型拟合
array_data = enc.transform(x_num).toarray() # one-hot 编码后的结果,二维数组形式
# 转成 dataframe 形式
df = pd.DataFrame(array_data)
inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值
# columns 重命名
if type(x) == pd.Series:
firs_name = x.name
else:
firs_name = ""
df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]
return df
def ZscoreNormalization(x):
'''
Z-score 标准化
'''
return (x - np.mean(x)) / np.std(x)
def DataClean(df,Lable=True):
'''
数据预处理函数
'''
########################## 1、Pclass 乘客等级 ##########################
# 无缺失值,等级变量
# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
df['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in df['Pclass']]
# 再对 PclassType 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['PclassType']),left_index=True,right_index=True)
########################## 2、Name 乘客姓名 ##########################
# 字符串变量
# 有无缺失值:无
# 从乘客姓名中获取头街
# 姓名中头街字符串与定义头街类别之间的关系
# Officer: 政府官员,
# RoyaIty: 王室(皇室),
# Mr: 已婚男士,
# Mrs: 已婚女士,
# Miss: 年轻未婚女子,
# Master: 有技能的人/教师
# 新建字段 Title_Dict
Title_Dict = {
'Mr':'Mr',
'Mrs':'Mrs',
'Miss':'Miss',
'Master': 'Master',
'Don':'Royalty',
'Rev':'Officer',
'Dr':')fficer',
'Mme':'Mrs',
'Ms':'Mrs',
'Major':'Officer',
'Lady': 'Royalty',
'Sir': 'Royalty',
'Mlle':'Miss',
'Col': 'Officer',
'Capt':'Officer',
'the Countess': 'Royalty',
'Jonkheer': 'Royalty',
'Dona': 'Royalty'
}
df['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in df['Name']] # 对Name进行分类
# 数据进一步处理:将 NameType 分成三类
# Mr(已婚男士)
# Mrs(已婚女士)、Miss(年轻未婚女子)
# 其他
df['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \
for i in df['NameType']]
# 再对 NameType2 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['NameType2']),left_index=True,right_index=True)
########################## 3、Sex 性别 ##########################
# 分类变量
# 有无缺失值:无
# 对 Sex 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['Sex']),left_index=True,right_index=True)
########################## 4、Age 年龄 ##########################
# 连续变量
# 有无缺失值:有,缺失比例19.9%
# 缺失值用均值填充
df['Age'] = df['Age'].fillna(df['Age'].mean())
# 数据处理:将 Age 分成两类,Age<=5、Age>5
df['AgeType'] = ["Age<=5" if i <= 5 else "Age>5" for i in df['Age']]
# 再对 AgeType 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['AgeType']),left_index=True,right_index=True)
########################## 5、SibSp 堂兄弟妹个数 ##########################
# 无缺失值,等级变量
# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
df['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in df['SibSp']]
# 再对 SibSpType 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['SibSpType']),left_index=True,right_index=True)
########################## 6、Parch 父母与小孩的个数 ##########################
# 连续变量
# 有无缺失值:无
# 数据处理:将 Parch 分成两类,Parch=0、Parch>0
df['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in df['Parch']]
# 再对 ParchType 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['ParchType']),left_index=True,right_index=True)
########################## 8、Fare 票价 ##########################
# 连续变量
# 有无缺失值:无
# 对 Fare 分成三类
# Fare = 0
# Fare <=50
# Fare > 50
df['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in df['Fare']]
# 再对 FareType 进行One-Hot编码处理
df = pd.merge(df,OneHot(df['FareType']),left_index=True,right_index=True)
########################## 10、Embarked 登船的港口 ##########################
# 离散变量
# 有无缺失值:有,缺失值比例很低
# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
mode = stats.mode(df['Embarked'])[0][0] # 众数
df['Embarked'] = df['Embarked'].fillna(mode)
df = pd.merge(df,OneHot(df['Embarked']),left_index=True,right_index=True)
########################## 11、FamilyNumbers 家庭人数 ##########################
# 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
df['FamilyNumbers'] = df['SibSp'] + df['Parch'] + 1
# 新增 FamilyType 字段
# 1 : 单身(Single)
# 2-4:小家庭(Family_Small)
# >4: 大家庭(Family_Large)
df['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in df['FamilyNumbers']]
# 对 FamilyType 进行One-hot编码处理
df = pd.merge(df,OneHot(df['FamilyType']),left_index=True,right_index=True)
########################## 删除冗余变量 ##########################
drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\
'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\
'FamilyNumbers','FamilyType']
df.drop(drop_columns,axis=1,inplace=True)
########################## 数据标准化 ##########################
if Lable == True: # 判断是否是测试集(测试集不含标签)
data = df.drop("Survived",axis=1).agg(ZscoreNormalization)
data['Lable'] = df['Survived']
else:
data = df.agg(ZscoreNormalization)
return data
def sklearn_DecisionTreeClassifier(data):
'''
决策树二分类
'''
# 划分训练集、测试集
x_train, x_test, y_train, y_test = train_test_split(
data.drop("Lable",axis=1),
data['Lable'],
test_size = 0.3,
random_state = 0
)
print("\n---------- 模型训练 ----------")
# 网格寻参
param_grid = {
'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”
'splitter' : ['best','random'], # 划分方式:{“best”, “random”}, default=”best”
'max_depth' : range(1,6), # 最大深度
'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数
'min_samples_leaf' : range(1,6), # 叶节点所需的最小样本数
}
clf = DecisionTreeClassifier() # 初始化
gs = GridSearchCV(clf,param_grid,cv=5) # 网格搜索与交叉验证
gs.fit(x_train,y_train) # 模型训练
print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
print("Best Score: ",gs.best_score_) # 打印最好分数
# 模型预测
print("\n---------- 模型评价 ----------")
y_pred = gs.predict(x_test) # 预测
cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
df_cm = pd.DataFrame(cm) # 构建DataFrame
print('Accuracy score:', accuracy_score(y_test, y_pred)) # 准确率
print('Recall:', recall_score(y_test, y_pred, average='weighted')) # 召回率
print('F1-score:', f1_score(y_test, y_pred, average='weighted')) # F1分数
print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度
return gs.best_estimator_ # 返回最好的训练模型
if __name__ == "__main__":
train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")
print("\n---------- 数据预处理 ----------")
train_data = DataClean(train)
test_data = DataClean(test,Lable=False)
# 决策树二分类
best_estimator = sklearn_DecisionTreeClassifier(train_data)
# 预测
y_pred = best_estimator.predict(test_data)
# 输出预测结果
result = test[['PassengerId']]
result['Survived'] = y_pred
result.to_csv("Titanic Results.csv",index=False)
print("\n程序运行完成")
四、Kaggle 得分
- 得分:0.77511
- 排名:7651
参考:
1、Kaggle泰坦尼克号比赛项目详解
2、机器学习实战——kaggle 泰坦尼克号生存预测——六种算法模型实现与比较