【Kaggle】Titanic幸存者预测(2019年1月第一次尝试)

Kaggle 专栏收录该内容
1 篇文章 0 订阅

前言

泰坦尼克的沉没是历史上著名的沉船事故之一。1912年4月15日,被认为是“永不沉没”的泰坦尼克号渡轮驶上了它的处女航。然而不幸的是却在航行途中撞上了冰山面临沉船。更不幸的是船上并没有配备足够的逃生船,导致了2224名乘客中有1502名乘客葬身大海。虽然有一些幸运的幸存者活了下来,但仔细分析一下不难发现有那么一群人有着更大的可能活下来。因此需要通过电脑来预测一下什么样的人群能够幸存。

泰坦尼克号幸存者预测是以著名泰坦尼克号沉默的悲剧为背景的二分类问题。它提供了泰坦尼克当时船员的数据,一共包含891个训练样本(train.csv)和418个测试样本(test.csv),要求使用891个样本训练模型,以预测418个测试样本幸存(Survived)与否,其中1表示幸存,0表示遇难。


一、问题目标

  • 利用Python实现对数据的清洗、转换及特征工程,利用相关性找出影响乘客生还的关键因素,最后建立模型预测出哪些乘客能够幸存,并得出预测准确率,将结果提交至kaggle上获得评分。
  • 问题链接:Kaggle - Titanic
  • Github代码:Github

二、问题分析流程设计

问题分析流程设计


三、数据处理

3.1 数据集初步探索

3.1.1 数据集获取

  • 从网站中获取两个数据集:train.csv及test.csv,其中train为训练集,test为测试集。
  • Python第三方库导入
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency
from sklearn.model_selection import StratifiedKFold, learning_curve
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn import svm
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
import seaborn as sns

3.1.2 数据集基础情况

  • 从网站上给的Data信息可以了解数据集中各标签的含义,如下表所示。
VariableDefinitionKey
survivalSurvival0 = No;1 = Yes
pclassTicket class1 = 1st;2 = 2nd;3 = 3rd
sexSex
AgeAge in years
sibsp# of siblings / spouses aboard the Titanic
parch# of parents / children aboard the Titanic
ticketTicket number
farePassenger fare
cabinCabin number
embarkedPort of EmbarkationC = Cherbourg;Q = Queenstown; S = Southampton
  1. 第一步:通过read_csv()方法分别读入两数据集,分别查看训练集和测试集的数据情况,
data_train = pd.read_csv("train.csv")
data_test = pd.read_csv("test.csv")
  1. 第二步:使用head( )方法查看训练集的前五行如2图所示,测试集的前五行如下图所示。从中可以看出,数据共有11个特征和1个标签,Survived列为标签,其中0代表死亡,1代表存活,这也是要求预测的数据。
print(data_train.head())
print(data_test.head())
  • 训练集前五行展示图
    训练集前五行展示图
  • 测试集前五行展示图
    测试集前五行展示图
  1. 第三步:使用shape()方法查看两数据集的形状。可以看出,训练集数据共有891行12列,测试集数据有418行,11列。
print(data_train.shape, data_test.shape)
  • 运行结果:
    数据集的
  1. 第四步,使用info( )方法获取数据的信息情况。
print(data_train.info())
print(data_test.info())
  • 训练集的信息情况图
    训练集的信息情况图
  • 测试集的信息情况图
    测试集的信息情况图
  • 从两图中可以看出数据类型的基本信息,训练集的数据共有2个双精度浮点型数据,5个整型数据,另外5个均为字符型数据,测试集缺少标签Survived列,与训练集相比少一个整形数据。(若将训练集和测试集合并,由于测试集缺少Survived列,合并后的Survived数据的类型将变为双精度浮点型数据,因此合并后的数据将有3个浮点型数据、4个整形数据和5个字符型数据,与原始数据不大相同),另外还可以看出内存占用情况。
  • 此外,从中也可以初步判断数据的缺失情况,训练集共891行,其中Age、Cabin和Embarked分别仅有714、204和889个数据,可初步认定数据缺失;测试集共418行,其中Age和Fare和Embarked分别仅有332、417、91个数据,可初步认定数据缺失。
  1. 第五步,使用describe( )方法查看数据情况。
print(data_train.describe())
print(data_test.describe())
  • 训练集的数据情况图
    训练集的数据情况图

  • 测试集的数据情况图
    测试集的数据情况图

  • 从两图中可以看到数据各特征的有效数据数目、平均值、标准差、最小值、上四分位数、中位数、下四分位数、最大值,初步了解数据特征。同时根据有效数据数目的情况也可初步判断缺失值,方法与上述相同。


3.2 数据预处理

3.2.1 数据探索

3.2.1.1 性别(Sex)

  • 画出直方图
color = {0: 'r', 1: 'g'}
sns.countplot(x='Sex', data=data_train, hue='Survived', palette=color)
plt.show()

不同性别的生存人数

  • 卡方检验验证
sex_survived_pivot_table = pd.pivot_table(
    data_train,
    index='Sex',
    columns='Survived',
    values='PassengerId',
    aggfunc='count')
print(sex_survived_pivot_table)
print(chi2_contingency(sex_survived_pivot_table.values)[1])
  • 最终输出的检验的值为1.19×e-58,远小于0.05,因此可以判断性别与存活与否有很强的相关性。

3.2.1.2 年龄(Age)

  • 绘制各年龄段的生存人数分布图代码
fig, axes = plt.subplots(nrows=1, ncols=2)
fig.set_size_inches(18, 6)
data_train[data_train['Survived'] == 1]['Age'].hist(color='g', ax=axes[0])
axes[0].set_title('Survived Age hist', size=18)
data_train[data_train['Survived'] == 0]['Age'].hist(color='r', ax=axes[1])
axes[1].set_title('Died Age hist', size=18)
plt.show()
  • 各年龄段的生存人数分布图
    各年龄段的生存人数分布图
  • 从图中所给出的整体比重来看,孩童存活的比重和人数均高于遇害,年轻人存活和遇害的比重都较大,存活比重要高于遇害比重,30-40岁的人存活比重要高于遇害,之后中年人存活和遇害比重相当,老年人的遇害比重要高于存活。由此可得,年龄与存活与否有较大的相关性。

3.2.1.3 票价(Fare)

  • 票价与生存与否之间关系的箱线图
    票价与生存与否之间关系的箱线图
  • 忽略部分异常值后,可发现票价越高,存活的人数就越多,由图12来看该现象较为明显。由此可得票价与存活有较强相关性。
  • 票价与生存概率的分布图
    票价与生存概率的分布图
  • 从图中可以看出,遇害人数在票价较低的人群中占有很大的比例,而票价较高的人群中,遇害人数较少,进一步证实了票价与存活有较强的相关性。
  • 绘制生存人数与票价之间的关系图代码
plt.figure(figsize=(18, 8))
sns.distplot(data_train[data_train['Survived'] == 1]['Fare'], color='g')
sns.distplot(data_train[data_train['Survived'] == 0]['Fare'], color='r')
plt.show()

3.2.1.4 船舱等级(Pclass)

  • 船舱等级与生存人数关系的直方图
    船舱等级与生存人数关系的直方图
  • 从图中可以看出,船舱等级较低的三等舱遇害人数显著多余另外两舱,而一等船舱的存活人数高于遇害人数,这说明船舱等级与存活有较强相关性。
  • 代码如下:
color = {0: 'r', 1: 'g'}
sns.countplot(x='Pclass', hue='Survived', palette=color, data=data_train)
plt.show()
  • 同样对其进行卡方验证,算出的值为4.549251711298793×e-23,远小于0.05,证实相关性极强。
  • 代码如下:
pclass_survived_pivot_table = pd.pivot_table(  # 卡方检验验证
    data_train,
    index='Pclass',
    columns='Survived',
    values=['PassengerId'],
    aggfunc='count')
print(pclass_survived_pivot_table)
print(chi2_contingency(pclass_survived_pivot_table.values)[1])
  • 另一方面,船舱等级与票价之间也有一定关系。
    船舱等级与票价之间关系的箱线图
  • 图中的箱线图较扁,不是很适合观察。不过还是能够看出,船舱等级高票价高的存活人数多,同时说明说明船舱等级和票价并非独立,两者之间有联系,符合常识,两者与存活均有较强相关性。
  • 代码如下
color = {0: 'r', 1: 'g'}
plt.figure(figsize=(18, 6))
sns.boxplot(x='Pclass', y='Fare', data=data_train, hue='Survived', palette=color)
plt.show()

3.2.1.5 登船港口(Embarked)

  • 登船港口与生存人数关系的直方图
    登船港口与生存人数关系的直方图
  • 从图中可以看出,在C港口登船的存活比例较大,S港口登船遇害人数较多,说明登船港口与存活有较强相关性。同样对其进行卡方验证,得到的检测数据为1.769922284120912×e-06,远小于0.05,证实了强相关性。
  • 代码如下:
color = {0: 'r', 1: 'g'}
sns.countplot(x='Embarked', hue='Survived', palette=color, data=data_train)
plt.show()

embarked_survived_pivot_table = pd.pivot_table(
    data=data_train,
    index='Embarked',
    columns='Survived',
    values='PassengerId',
    aggfunc='count'
)
print(embarked_survived_pivot_table)
print(chi2_contingency(embarked_survived_pivot_table.values)[1])

  • 另一方面,登船港口与票价之间也有一定的关系。
    登船港口与票价之间关系的箱线图

  • 从图中可以看出,C港口登船票价较高的存活人数较多,说明港口与票价两者有一定联系,与存活有较强相关性。

3.2.1.6 数据探索总结

  • 综上所述,很多特征都与存活有着较强的关系,例如登船者的性别、年龄、票价、船舱、登船港口等。这些标签都可以作为关键的数据,不过也有一些标签是关系不大的,可在之后进行删除。另外,乘船是一个很有生活色彩的话题,特征之间的关系及其与目标变量之间的关系也可以根据生活经验来判断,一些较难用数据判断的相关性也可用生活经验来推断。

3.2.2 数据清洗

  • 分析出不同特征对生存结果的影响程度之后,接下来要对数据进行清洗,使得数据变得规整,才能够输入机器学习的算法。

3.2.2.1 数据合并

  • 首先记录下训练集的标签数据,用于之后训练数据,之后将其从数据集中剔除,同时记录下测试集的ID信息,用于提交结果时使用,之后将训练集和测试集合并。代码如下:
Y_train = data_train.Survived
PassengerId = data_test.PassengerId
data_train.drop(['Survived'], axis=1, inplace=True)
combined = pd.concat([data_train, data_test], sort=False, axis=0)
combined.drop(['PassengerId'], inplace=True, axis=1)
print(combined.shape)

3.2.2.2 缺失值处理

  • 在数据集初步探索的过程中,可初步判断四个特征值存在缺失值,分别是Age、Fare、Embarked和Cabin。现对合并的数据进行进一步判断,对合并后数据缺失值使用is null( )方法进行判断,判断的结果如图所示。
print(combined.isnull().sum())

合并后的数据集中缺失值的数量

  1. Age缺失值的处理
  • Age缺失数目为263个,数目较多,且其为连续性变量,因此可用平均数来填充缺失值。若用数据整体的平均数来处理相应的值,可能会有较大的误差,因此可先将数据分组,以减小误差。
  • 根据实际情况考虑,舱位的等级在一定程度上与年龄有关,年长的人一般处于较高的等级,年轻人一般处于较低的登记。其次,根据数据的观察,可以发现乘客的姓名都带有前缀,姓名的前缀也是年龄的考虑因素,例如通常情况下,Mrs要年长于Miss。除此之外,性别也可以作为一个参考因素,一般情况下,乘坐豪华邮轮的男女的年龄并不会是相同的。
  • 因此,将结合性别、船舱等级和姓名,对年龄的缺失值进行填充。首先,提取名字中的前缀,对前缀进行分组,获得Title信息。之后根据乘客的性别、船舱等级和姓名对乘客进行分组,计算出每组乘客年龄的平均值,将其作为填充的标准。
def select_group_age_median(row):
    condition = ((row['Sex'] == age_group_mean['Sex']) &
                 (row['Pclass'] == age_group_mean['Pclass']) &
                 (row['Title'] == age_group_mean['Title']))
    return age_group_mean[condition]['Age'].values[0]


def age_bin(x):
    if x <= 18:
        return 'child'
    elif x <= 30:
        return 'young'
    elif x <= 55:
        return 'midlife'
    else:
        return 'old'


age_group_mean = combined.groupby(['Sex', 'Pclass', 'Title'])['Age'].mean().reset_index()

根据Sex、Pclass和Title分组的结果

  • 数据共分为17组,每一组都有一个Age的平均值。经观察,年龄符合乘客的特征,可作为填充数据的标准。
  • 之后找到缺失数据的乘客,判断其类型,填充该类型的年龄平均值。
combined['Age'] = combined.apply(lambda x: select_group_age_median(x) if np.isnan(x['Age']) else x['Age'], axis=1)
  1. Fare缺失值的处理
  • Fare也为连续型变量,实际上同样受到其他标签,如船舱等级的影响,但Fare仅有一个缺失值,对数据整体的影响不大,因此可直接将Fare整体的平均值填充到缺失值当中。
combined['Fare'].fillna(combined['Fare'].mean(), inplace=True)
  1. Cabin缺失值的处理
  • Cabin为离散型数据,应用众数来填充,但Cabin缺失数据的数目过多,共1014个缺失值,仅有295个数据,用295个数据的众数作为1014个数据的填充值是很不合理的,会有很大的误差,由于并没有很好的方法来解决这一问题,因此也没有进行直接的填充,而是改变了Cabin的表现形式。缺失值用no表示,非缺失值用yes表示,运行结果为缺失值和非缺失值各自的数目,检测时不会出现缺失值。
combined.loc[combined['Cabin'].notnull(), 'Cabin'] = 'yes'
combined.loc[combined['Cabin'].isnull(), 'Cabin'] = 'no'
  1. Embarked缺失值的处理
  • Embarked为离散性数据,且缺失数据只有两个,因此可用众数来填充缺失值。
combined['Embarked'].fillna(combined['Embarked'].mode(), inplace=True)

3.2.2.3 数据类型处理

  • 缺失值全部处理完成后,下一个问题是,由于部分连续性变量与目标变量之间并不是线性关系,为了简化问题,需要将其转变为离散型变量。
  • 在本次问题中,年龄便是这样一种变量。因此可以将它转化为离散型变量,可将年龄分段,如青年、中年、老年,更新年龄信息数据,将年龄标签转化为离散型变量。
def age_bin(x):
    if x <= 18:
        return 'child'
    elif x <= 30:
        return 'young'
    elif x <= 55:
        return 'midlife'
    else:
        return 'old'


combined['age_bin'] = combined['Age'].map(age_bin)
  • 另外, 观察数据可以发现,兄弟姐妹、父母子女两个特征其实可以作为一类信息,都属于家庭成员,因此可以将两个特征合为一个特征:家庭人数。
  • 家庭人数也为连续型变量,与目标变量也不是线性关系,因此也可将家庭人数分段,如大型家庭、小型家庭等,更新家庭人数标签,将家庭人数标签转变为离散变量。
def deal_with_family_size(num):
    if num == 1:
        return 'Singleton'
    elif num <= 4:
        return 'SmallFamily'
    elif num >= 5:
        return 'LargeFamily'
    return num


combined['FamilySize'] = combined['SibSp'] + combined['Parch'] + 1
combined['FamilySize'] = combined['FamilySize'].map(deal_with_family_size)
  • 最后,由于字符型的数据是不能直接直接输入的,故需要将其进行转化,进行特征值的提取和转化及数据集的分离。
  • 性别、船舱、港口、姓名都是字符型数据,均需要转化。船舱等级没有明显的数字特征,也需要对其进行转化,此外,第四步操作中的离散型年龄和家庭成员数目均为字符型数据,均需要转化。
  • 七项标签中,姓名只有头衔为有用的部分,已在年龄填充时独立出来,可用头衔来代替姓名。
  • 由于票号不会影响到目标变量,将其剔除。
combined['title'] = combined['Name'].map(lambda x: x.split(',')[1].split('.')[0].strip())
Title_Dictionary = {
    "Mr": "Mr",
    "Mrs": "Mrs",
    "Miss": "Miss",
    "Master": "Master",
    "Don": "Royalty",
    "Rev": "Officer",
    "Dr": "Officer",
    "Mme": "Mrs",
    "Ms": "Mrs",
    "Major": "Officer",
    "Lady": "Royalty",
    "Sir": "Royalty",
    "Mlle": "Miss",
    "Col": "Officer",
    "Capt": "Officer",
    "the Countess": "Royalty",
    "Jonkheer": "Royalty",
    "Dona": 'Mrs'
}
combined['Title'] = combined['title'].map(Title_Dictionary)

combined = pd.get_dummies(
    combined,
    columns=['Sex', 'Cabin', 'Pclass', 'Embarked', 'Title', 'FamilySize', 'age_bin'],
    drop_first=True)
combined.drop(['Ticket'], axis=1, inplace=True)
  • 截止到现在,数据都是规整的了,接着进行数据处理的最后一步,将数据根据训练集和测试集原本的数据量,将数据重新分开,重新划分训练集和测试集。
X_train = combined.iloc[:891]
X_test = combined.iloc[891:]
  • 此时检查一下各特征值的相关性。特征对应的值越大,说明特征与变量的相关性越强。从图中可以看出,费用、年龄与目标变量有很强的相关性,姓名头衔为王室、年龄为老人区间与目标变量的相关性较小。
  • 由于总体特征值数量较少,且前面操作已经去除部分特征值,故为了数据的准确,不再进行降维操作。
  • 代码如下:
feature_importance = pd.Series(rfc.feature_importances_, X_train.columns)
feature_importance.sort_values(ascending=False, inplace=True)
print(feature_importance)

feature_importance.plot(kind='barh')
plt.savefig("不同特征对影响结果的重要性")
plt.show()

不同特征对影响结果的重要性


四、构建分析模型及预测

4.1 模型构建

  • 使用随机森林算法构建模型的第一层。随机森林建立了多个决策树,采用随机有放回的选择训练数据然后构造分类器,最后组合学习到的模型来增加整体的效果,将它们合并在一起以获得更准确和稳定的预测。随机森林的方法既对训练样本进行了采样,又对特征进行了采样,充分保证了所构建的每个树之间的独立性,使得预测结果更准确。
    随机森林算法的示意图
  • 接下来,添加网格搜索开关,形成模型的第二层。网格搜索法是指定参数值的一种穷举搜索方法,通过将估计函数的参数通过交叉验证的方法进行优化来得到最优的学习算法。即将各个参数可能的取值进行排列组合,列出所有可能的组合结果生成“网格”。然后将各组合用于SVM训练,并使用交叉验证对表现进行评估。在拟合函数尝试了所有的参数组合后,返回一个合适的分类器,自动调整至最佳参数组合。
    网格搜索法的示意图
  • 代码如下:
run_gs = True

if run_gs:
    parameter_grid = {
        'max_depth': [4, 6, 8],
        'n_estimators': [50, 10],
        'max_features': ['sqrt', 'auto', 'log2'],
        'min_samples_split': [2, 3, 10],
        'min_samples_leaf': [1, 3, 10],
        'bootstrap': [True, False],
    }
    forest = RandomForestClassifier()
    cross_validation = StratifiedKFold(n_splits=5)

    grid_search = GridSearchCV(
        forest,
        scoring='accuracy',
        param_grid=parameter_grid,
        cv=cross_validation,
        verbose=1)

    grid_search.fit(train_reduced, Y_train)
    # model = grid_search
    parameters = grid_search.best_params_

    print('Best score: {}'.format(grid_search.best_score_))
    print('Best parameters: {}'.format(grid_search.best_params_))
else:
    parameters = {
        'bootstrap': False,
        'min_samples_leaf': 1,
        'n_estimators': 10,
        'min_samples_split': 3,
        'max_features': 'log2',
        'max_depth': 8
    }

print(parameters)

model = RandomForestClassifier(**parameters)
model.fit(X_train, Y_train)
y_predict = model.predict(X_test)
res = pd.DataFrame({'PassengerId': PassengerId, 'Survived': y_predict})
res.to_csv('预测结果.csv', index=False)

4.2 模型验证

  • 对于模型的验证,采用了学习曲线和验证曲线的方式进行验证。
  • 学习曲线是一种用来判断训练模型的一种方法,通过观察绘制出来的学习曲线图,我们可以比较直观的了解到我们的模型处于一个什么样的状态,如:过拟合或欠拟合。一个比较理想的学习曲线图应当是:低偏差、低方差,即收敛且误差小。
    学习曲线的示意图
  • 验证曲线表现的是模型得分与模型复杂度的关系,学习曲线表现的是模型得分与训练数据集规模的关系。
  • 具有高偏差的模型被认为是对数据欠拟合,模型在验证集的表现与在训练及的表现类似;具有高方差的模型被认为是对数据过拟合,模型在验证集的表现远不如在训练集的表现。
    验证曲线的示意图
  • 下图为本模型的学习曲线。一个学习曲线显示一个估计量的训练分数和验证分数随着训练样本量的变化情况。学习曲线可以帮助我们找出增加更多训练数据的受益程度,估计量是否遭遇方差或偏差误差。图中的学习曲线表明,在本预测过程中,随着样本数据集规模的增大,模型得分逐渐提高,在一定程度后趋于稳定。
    学习曲线图
  • 代码如下:
def plot_learning_curve(estimator, title, x, y, ylim=None, cv=None, n_jobs=1,
                        train_sizes=np.linspace(.05, 1., 20), verbose=0, plot=True):
    """
    画出data在某模型上的learning curve.
    参数解释
    ----------
    estimator : 用的分类器。
    title : 表格的标题。
    X : 输入的feature,numpy类型
    y : 输入的target vector
    ylim : tuple格式的(ymin, ymax), 设定图像中纵坐标的最低点和最高点
    cv : 做cross-validation的时候,数据分成的份数,其中一份作为cv集,其余n-1份作为training(默认为3份)
    n_jobs : 并行的的任务数(默认1)
    """
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, x, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes, verbose=verbose)

    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    if plot:
        plt.figure()
        plt.title(title)
        if ylim is not None:
            plt.ylim(*ylim)
        plt.xlabel(u"Sample Amount")
        plt.ylabel(u"Score")
        plt.gca().invert_yaxis()
        plt.grid()

        plt.fill_between(train_sizes, train_scores_mean - train_scores_std, train_scores_mean + train_scores_std,
                         alpha=0.1, color="b")
        plt.fill_between(train_sizes, test_scores_mean - test_scores_std, test_scores_mean + test_scores_std,
                         alpha=0.1, color="r")
        plt.plot(train_sizes, train_scores_mean, 'o-', color="b", label=u"Train Set Score")
        plt.plot(train_sizes, test_scores_mean, 'o-', color="r", label=u"Cross-Validation Score")

        plt.legend(loc="best")

        plt.draw()
        plt.gca().invert_yaxis()
        plt.show()

    midpoint = ((train_scores_mean[-1] + train_scores_std[-1]) + (test_scores_mean[-1] - test_scores_std[-1])) / 2
    diff = (train_scores_mean[-1] + train_scores_std[-1]) - (test_scores_mean[-1] - test_scores_std[-1])
    return midpoint, diff


plot_learning_curve(rfc, u"Learning Curve", X_train, Y_train)
  • 下图为本模型的验证曲线。从图中可以看出随着内核参数gamma的增大,交叉验证曲线存在先上升后下降的过程。当gamma趋近10-3.5时,能够达到最好的效果。
    验证曲线图
  • 代码如下:
igits = load_digits()
X = digits.data
Y_train = digits.target
param_range = np.logspace(-6, -1, 5)
vsc = svm.SVC()
train_score, test_score = validation_curve(vsc, X, Y_train,
                                           param_name='gamma',
                                           param_range=param_range,
                                           cv=10,
                                           scoring="accuracy",
                                           n_jobs=1)
train_score_mean = np.mean(train_score, axis=1)
train_score_std = np.std(train_score, axis=1)
test_score_mean = np.mean(test_score, axis=1)
test_score_std = np.std(test_score, axis=1)
plt.title("Validation curve with SVM")
plt.xlabel("Gamma")
plt.ylabel("Score")
plt.ylim()
lw = 2
plt.semilogx(param_range, train_score_mean, label="Training score", color="r", lw=lw)
plt.semilogx(param_range, test_score_mean, label="Cross-validation", color="g", lw=lw)
plt.fill_between(param_range, train_score_mean - train_score_std, train_score_mean + train_score_std, alpha=0.2,
                 color="navy", lw=lw)
plt.legend(loc="best")
plt.show()

4.3 模型评价

  • 预测结果如下图所示。其中Best score为取得的最高准确率,Best parameters为通过网格搜索验证算法取得的最优参数。
    程序预测结果输出图
  • 完成功能后,提交上kaggle进行测试。
    Kaggle结果

总结

  • 本模型以大量决策树构成随机森林为基础构建,通过网格搜索验证算法取得最优参数,能够完整实现使用训练样本训练模型,并利用其预测幸存者的功能,在训练数据集上的得分是0.8103,在测试数据集上的得分是0.7894,说明模型具有不错的泛化功能。
  • 模型可能存在的一些有待改进之处:
    1. 特征工程方面,对各特征数据两两之间的相关性分析不足,如果利用特征数据之间的相关性强化数据的补充和筛选,可能对提高预测准确率有所帮助。
    2. 通过网格搜索获取的参数已经是最优参数了,如果追求更高的预测准确率,可以改用更复杂的模型,如使用Boosting,Stacking方法将AdaBoost、ExtraTrees、GBDT、KNN等基学习器添加至本模型中,可能对提高预测准确率有所帮助。

有用的话记得一键三连哦!!

  • 1
    点赞
  • 3
    评论
  • 6
    收藏
  • 打赏
    打赏
  • 扫一扫,分享海报

©️2022 CSDN 皮肤主题:创作都市 设计师:CSDN官方博客 返回首页

打赏作者

MomentNi

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值