Titanic: Machine Learning from Disaster
Goal: use machine learning to build a model that predicts which passengers survived the sinking of the Titanic.
import warnings
warnings.filterwarnings("ignore")
I will explore the data, process it, and impute the missing values. Feature engineering is a key part of the machine-learning workflow, so I want to spend extra time on that step. I will try several models and report which one fits the competition's train dataset best.
Import libraries
In [2]:
import pandas as pd
import numpy as np
import datetime
import math
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels
import xgboost as xgb
from scipy import stats
from scipy.stats import skew
from sklearn.metrics import accuracy_score, confusion_matrix
np.random.seed(2019)
%matplotlib inline
print("done")
done
Import the data
I add a 'train' column here so that, after concatenating the train and test datasets, I can tell in the simplest possible way which observations came from which set.
In [3]:
def read_and_concat_dataset(training_path, test_path):
    train = pd.read_csv(training_path)
    train['train'] = 1
    test = pd.read_csv(test_path)
    test['train'] = 0
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
    data = pd.concat([train, test], ignore_index=True)
    return train, test, data
train, test, data = read_and_concat_dataset('../input/Titanic/train.csv', '../input/Titanic/test.csv')
data = data.set_index('PassengerId')
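The 'train' flag added above lets us split the combined frame back apart at any time. A minimal sketch of that round trip, using tiny stand-in frames instead of the real CSV files:

```python
import pandas as pd

# Toy stand-ins for train.csv / test.csv (hypothetical minimal columns)
train = pd.DataFrame({'PassengerId': [1, 2], 'Survived': [0, 1]})
test = pd.DataFrame({'PassengerId': [3]})
train['train'] = 1
test['train'] = 0
data = pd.concat([train, test], ignore_index=True)

# The flag recovers the original split from the combined frame
recovered_train = data[data['train'] == 1]
recovered_test = data[data['train'] == 0]
print(len(recovered_train), len(recovered_test))  # 2 1
```

Test rows get `Survived = NaN` after the concat, which is another way to tell the sets apart, but the explicit flag survives any later reshaping.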
Explore the Data
In [4]:
data.head(5)
Out[4]:
PassengerId | Age | Cabin | Embarked | Fare | Name | Parch | Pclass | Sex | SibSp | Survived | Ticket | train
---|---|---|---|---|---|---|---|---|---|---|---|---
1 | 22.0 | NaN | S | 7.2500 | Braund, Mr. Owen Harris | 0 | 3 | male | 1 | 0.0 | A/5 21171 | 1
2 | 38.0 | C85 | C | 71.2833 | Cumings, Mrs. John Bradley (Florence Briggs Th… | 0 | 1 | female | 1 | 1.0 | PC 17599 | 1
3 | 26.0 | NaN | S | 7.9250 | Heikkinen, Miss. Laina | 0 | 3 | female | 0 | 1.0 | STON/O2. 3101282 | 1
4 | 35.0 | C123 | S | 53.1000 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 1 | female | 1 | 1.0 | 113803 | 1
5 | 35.0 | NaN | S | 8.0500 | Allen, Mr. William Henry | 0 | 3 | male | 0 | 0.0 | 373450 | 1
In [5]:
data.describe()
Out[5]:
 | Age | Fare | Parch | Pclass | SibSp | Survived | train |
---|---|---|---|---|---|---|---|
count | 1046.000000 | 1308.000000 | 1309.000000 | 1309.000000 | 1309.000000 | 891.000000 | 1309.000000 |
mean | 29.881138 | 33.295479 | 0.385027 | 2.294882 | 0.498854 | 0.383838 | 0.680672 |
std | 14.413493 | 51.758668 | 0.865560 | 0.837836 | 1.041658 | 0.486592 | 0.466394 |
min | 0.170000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 |
25% | 21.000000 | 7.895800 | 0.000000 | 2.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 28.000000 | 14.454200 | 0.000000 | 3.000000 | 0.000000 | 0.000000 | 1.000000 |
75% | 39.000000 | 31.275000 | 0.000000 | 3.000000 | 1.000000 | 1.000000 | 1.000000 |
max | 80.000000 | 512.329200 | 9.000000 | 3.000000 | 8.000000 | 1.000000 | 1.000000 |
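The count row of describe() already reveals the missing values: Age is known for only 1046 of 1309 passengers, Fare for 1308, and Survived only for the 891 train rows. The per-column missing counts can also be listed directly; a sketch on a toy frame with the same pattern of gaps:

```python
import pandas as pd
import numpy as np

# Toy frame mirroring the gaps seen in describe(): Age, Fare, Survived have NaNs
data = pd.DataFrame({
    'Age': [22.0, np.nan, 26.0, 35.0],
    'Fare': [7.25, 71.28, np.nan, 53.1],
    'Survived': [0.0, 1.0, np.nan, np.nan],
})

# Per-column missing counts, largest first
missing = data.isnull().sum().sort_values(ascending=False)
print(missing)
```

On the real combined frame this same call ranks Cabin, Survived, Age, Fare, and Embarked as the columns that need imputation.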
Variable descriptions
PassengerId : a unique ID for each row; it has no effect on survival.
Survived - binary:
- 1 -> survived
- 0 -> did not survive
Pclass (Passenger Class) - the passenger's socio-economic class; three values:
- 1 -> upper class
- 2 -> middle class
- 3 -> lower class
Name - passenger name; the title it contains hints at sex and age
SibSp - total number of the passenger's siblings and spouses aboard
Parch - total number of the passenger's parents and children aboard
Ticket - ticket number
Fare - passenger fare
Cabin - cabin number
Embarked : port of embarkation, 3 values:
- C -> Cherbourg
- Q -> Queenstown
- S -> Southampton
Correlation matrix between the numerical variables
In [6]:
g = sns.heatmap(data[["Survived","SibSp","Parch","Age","Fare"]].corr(),annot=True, cmap = "coolwarm")
The correlations between the numerical variables and Survived are not very strong, but they are still informative.
In [7]:
def comparing(data, variable1, variable2):
    known = data[[variable1, variable2]][data[variable2].isnull() == False]
    print(known.groupby([variable1], as_index=False).mean().sort_values(by=variable2, ascending=False))
    # distplot is deprecated in recent seaborn; histplot is its replacement
    g = sns.FacetGrid(data, col=variable2).map(sns.histplot, variable1)
In [8]:
def counting_values(data, variable1, variable2):
    known = data[[variable1, variable2]][data[variable2].isnull() == False]
    return known.groupby([variable1], as_index=False).mean().sort_values(by=variable2, ascending=False)
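Both helpers share the same pattern: drop rows where the target is missing (the test rows, where Survived is NaN), group by the feature, and take the mean of the target, which for a 0/1 target is exactly the survival rate per group. A self-contained sketch of that pattern on a toy frame:

```python
import pandas as pd
import numpy as np

# Toy frame: Survived is NaN for test rows, as in the combined dataset
df = pd.DataFrame({
    'Parch': [0, 0, 1, 1, 2],
    'Survived': [0.0, 1.0, 1.0, np.nan, 1.0],
})

# Same pattern as counting_values: keep rows with a known target,
# then the group mean of a 0/1 target is the group's survival rate
known = df[['Parch', 'Survived']][df['Survived'].isnull() == False]
rates = known.groupby(['Parch'], as_index=False).mean().sort_values(by='Survived', ascending=False)
print(rates)
```

The NaN filter matters: without it, pandas would silently drop the test rows from the mean anyway, but the group sizes reported elsewhere would be misleading.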
Parch vs Survived
In [9]:
comparing(data, 'Parch','Survived')
Parch Survived
3 3 0.600000
1 1 0.550847
2 2 0.500000
0 0 0.343658
5 5 0.200000
4 4 0.000000
6 6 0.000000
SibSp vs Survived
In [10]:
comparing(data, 'SibSp','Survived')
SibSp Survived
1 1 0.535885
2 2 0.464286
0 0 0.345395
3 3 0.250000
4 4 0.166667
5 5 0.000000
6 8 0.000000
Fare vs Survived
In [11]:
comparing(data, 'Fare','Survived')
Fare Survived
247 512.3292 1.0
196 57.9792 1.0
89 13.8583 1.0
88 13.7917 1.0
86 13.4167 1.0
..
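Because Fare is continuous, grouping by exact values as above mostly produces one-passenger groups, so the per-value rates are noisy. A common alternative (not part of the original notebook) is to bin fares into quantiles with pd.qcut first and compare survival per band; a sketch on synthetic data where higher fares are made more survivable by construction:

```python
import pandas as pd
import numpy as np

rng = np.random.default_rng(2019)
df = pd.DataFrame({'Fare': rng.uniform(0, 100, 200)})
# Synthetic target: survival probability rises with Fare
df['Survived'] = (df['Fare'] / 100 > rng.uniform(0, 1, 200)).astype(float)

# Bin Fare into quartiles, then compare survival rates per band
df['FareBand'] = pd.qcut(df['Fare'], 4)
rates = df.groupby('FareBand', observed=True)['Survived'].mean()
print(rates)
```

With quartile bands each group holds roughly a quarter of the passengers, so the per-band rates are stable enough to read a trend from, and the bands themselves can later serve as an engineered feature.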