前言:Titanic生存率预测是Kaggle平台上的经典竞赛项目,本文通过该项目展示了运用机器学习方法分析、解决问题的一般思路:即首先应明确要分析的问题和项目的目的,在搜集整理所需数据并理解数据之间的关系后,进一步对数据进行预处理以提升数据的质量,接下来将数据导入模型进行分析,比较不同模型效果并对模型进行评价,最终选择合适的模型对问题进行分析或预测。
1.背景介绍
泰坦尼克号于1909年3月31日在爱尔兰动工建造,1911年5月31日下水,次年4月2日完工试航。她是当时世界上体积最庞大、内部设施最豪华的客运轮船,有“永不沉没”的美誉。然而讽刺的是,泰坦尼克号首航便遭遇厄运:1912年4月10日她从英国南安普顿出发,途径法国瑟堡和爱尔兰昆士敦,驶向美国纽约。在14日晚23时40分左右,泰坦尼克号与一座冰山相撞,导致船体裂缝进水。次日凌晨2时20分左右,泰坦尼克号断为两截后沉入大西洋,其搭载的2224名船员及乘客,在本次海难中逾1500人丧生。
在学习机器学习相关项目时,Titanic生存率预测项目也通常是入门练习的经典案例。Kaggle平台为我们提供了一个竞赛案例“Titanic: Machine Learning from Disaster”,在该案例中,我们将探究什么样的人在此次海难中幸存的几率更高,并通过构建预测模型来预测乘客生存率。Titanic: Machine Learning from Disasterwww.kaggle.com
本文通过数据可视化理解数据,并利用特征工程等方法挖掘更多有价值的特征,然后利用同组效应找出共性较强的群体并对其数据进行修正,在选择模型时分别比较了Gradient Boosting Classifier、Logistic Regression等多种方法,最终利用Gradient Boosting Classifier对乘客的生存率进行预测。
最终,在kaggle上模型的得分为0.82775,328名,排名约为TOP3%。
2.数据准备
本文数据集是来源于Kaggle平台中“Titanic: Machine Learning from Disaster”竞赛项目,数据字段释义如下:
首先,导入基础包来搭建项目分析的环境,并加载相关数据。
#导入相关包
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import seaborn as sns
#设置sns样式
sns.set(style='white',context='notebook',palette='muted')
import matplotlib.pyplot as plt
#导入数据
train=pd.read_csv('Python数据/train.csv')
test=pd.read_csv('Python数据/test.csv')
3.理解数据
3.1 查看数据情况
1)查看数据量及数据特征,理解特征含义:
#分别查看实验数据集和预测数据集数据
print('实验数据大小:',train.shape)
print('预测数据大小:',test.shape)
>>>
实验数据大小: (891, 12)
预测数据大小: (418, 11)
该数据集共1309条数据,其中实验数据891条,预测数据418条;实验数据比预测数据多了一列:即标签"result"。
2)记录数据异常值、缺失值情况,方便下一步进行数据预处理。
#将实验数据和预测数据合并
full=train.append(test,ignore_index=True)
full.describe()
无明显的异常值,几乎所有数据均在正常范围内。
full.info()
>>>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 12 columns):
Age 1046 non-null float64
Cabin 295 non-null object
Embarked 1307 non-null object
Fare 1308 non-null float64
Name 1309 non-null object
Parch 1309 non-null int64
PassengerId 1309 non-null int64
Pclass 1309 non-null int64
Sex 1309 non-null object
SibSp 1309 non-null int64
Survived 891 non-null float64
Ticket 1309 non-null object
dtypes: float64(3), int64(4), object(5)
memory usage: 122.8+ KB
Age/Cabin/Embarked/Fare四项数据有缺失值,其中Cabin字段缺失近四分之三的数据。
3.2 查看特征与标签间关系
结合图表查看各个特征与标签间的关系。
3.2.1 Embarked与Survived:法国登船的乘客生存率较高
sns.barplot(data=train,x='Embarked',y='Survived')
#计算不同类型embarked的乘客,其生存率为多少
print('Embarked为"S"的乘客,其生存率为%.2f'%full['Survived'][full['Embarked']=='S'].value_counts(normalize=True)[1])
#'C','Q'代码类同'S',这里不赘述啦Embarked为"S"的乘客,其生存率为0.34
Embarked为"C"的乘客,其生存率为0.55
Embarked为"Q"的乘客,其生存率为0.39
法国登船乘客生存率较高原因可能与其头等舱乘客比例较高有关,因此继续查看不同登船地点乘客各舱位乘客数量情况。
#法国登船乘客生存率较高原因可能与其头等舱乘客比例较高有关
sns.factorplot('Pclass',col='Embarked',data=train,kind='count',size=3)
果然,法国登船的乘客其头等舱所占比例更高
3.2.2 Parch与Survived:当乘客同行的父母及子女数量适中时,生存率较高
sns.barplot(data=train,x='Parch',y='Survived')
3.2.3 SibSp与Survived:当乘客同行的同辈数量适中时生存率较高
sns.barplot(data=train,x='SibSp',y='Survived')
3.2.4 Pclass与Survived:乘客客舱等级越高,生存率越高
sns.barplot(data=train,x='Pclass',y='Survived')
3.2.5 Sex与Survived:女性的生存率远高于男性
sns.barplot(data=train,x='Sex',y='Survived')
3.2.6 Age与Survived:当乘客年龄段在0-10岁期间时生存率会较高
#创建坐标轴
ageFacet=sns.FacetGrid(train,hue='Survived',aspect=3)
#作图,选择图形类型
ageFacet.map(sns.kdeplot,'Age',shade=True)
#其他信息:坐标轴范围、标签等
ageFacet.set(xlim=(0,train['Age'].max()))
ageFacet.add_legend()
3.2.7 Fare与Survived:当票价低于18左右时乘客生存率较低,票价越高生存率一般越高
#创建坐标轴
ageFacet=sns.FacetGrid(train,hue='Survived',aspect=3)
ageFacet.map(sns.kdeplot,'Fare',shade=True)
ageFacet.set(xlim=(0,150))
ageFacet.add_legend()
查看票价的分布特征
#查看fare分布
farePlot=sns.distplot(full['Fare'][full['Fare'].notnull()],label='skewness:%.2f'%(full['Fare'].skew()))
farePlot.legend(loc='best')
fare的分布呈左偏的形态,其偏度skewness=4.37较大,说明数据偏移平均值较多,因此我们需要对数据进行对数化处理,防止数据权重分布不均匀。
#对数化处理fare值
full['Fare']=full['Fare'].ma