Kaggle Titanic 生存问题 - Java 机器学习实战

Kaggle Titanic 生存问题 - Java 机器学习实战

1. Kaggle Titanic 介绍

1.1 Kaggle 竞赛

Kaggle 是目前全球最有影响力的机器学习竞赛平台。

企业或者研究者可以将数据、问题描述、期望的指标发布到Kaggle上,以竞赛的形式向广大的数据科学家征集解决方案,类似于KDD-CUP(国际知识发现和数据挖掘竞赛)。Kaggle上的参赛者将数据下载下来,分析数据,然后运用机器学习、深度学习、数据挖掘等知识,建立算法模型,解决问题得出结果,最后将结果提交,如果提交的结果符合指标要求(比如logloss,mse等)并且在参赛者中排名第一,将获得比赛丰厚的奖金。

Titanic 竞赛是 Kaggle 提供的一个机器学习入门级比赛。主要是给新手练手用的(当然也没有奖金)。目前已经有将近1万只队伍参加过这个比赛了。

下面我们来具体看看 Titanic 问题。

1.2 Titanic 问题

我们首先看看 Competition Description:

The sinking of the RMS Titanic is one of the most infamous shipwrecks in history. On April 15, 1912, during her maiden voyage, the Titanic sank after colliding with an iceberg, killing 1502 out of 2224 passengers and crew. This sensational tragedy shocked the international community and led to better safety regulations for ships.

One of the reasons that the shipwreck led to such loss of life was that there were not enough lifeboats for the passengers and crew. Although there was some element of luck involved in surviving the sinking, some groups of people were more likely to survive than others, such as women, children, and the upper-class.

In this challenge, we ask you to complete the analysis of what sorts of people were likely to survive. In particular, we ask you to apply the tools of machine learning to predict which passengers survived the tragedy.

1912 年4月15日的泰坦尼克沉船事件杀死了 2224 名船员和乘客中的 1502 人。很多人认为救生艇的缺乏是造成大量人员伤亡的一个重要原因。

数据显示,一些群体在沉船事件中有更大的幸存可能。比如:妇女、儿童或者上层人士。

这个比赛的目的就是根据乘客的特征来预测其是否能够幸存。

kaggle 提供的数据主要分为两个文件: train.csv 和 test.csv。其中 train.csv 中的数据携带了是否幸存的信息,而 test.csv 则没有携带这一信息。我们将使用 train.csv 中的数据来训练模型,并在 test.csv 上执行预测。预测结果会提交给 kaggle 用于评价模型。

2. 数据探索

2.1 数据初探

首先我们看看 train.csv 文件的内容:

这里写图片描述

这里一共有 12 个字段。其中 PassengerId 仅仅是用作编号,没有其他用处。其他字段我们一一简要进行说明.

  • Survived:该乘客是否幸存. 1表示幸存,0表示未能幸存

  • Pclass: 表示该乘客的社会等级. 取值为1,2,3

  • Name:名字

  • Sex: 性别,分别取 male 或 female

  • Age: 年龄

  • SibSp: 船上有多少个兄弟姐妹

  • Parch: 船上有多少个长辈.

  • Ticket: 船票号码

  • Fare: 票价

  • Cabin: 客舱位置编号

  • Embarked: 登船的码头. 有S、Q 和 C

2.2 特征分析

在这一部分我们将逐一分析特征数据,并探索其同存活率之间的关系。

由于笔者对 Java 更熟悉一些,因此将使用 Java 语言,用 Spark 进行分析。Spark 下的 SQL 可以将数据当做数据库表一样使用 sql 来进行处理。这对熟悉 sql 的人来说无疑是一种福利~~~

首先进行初始化:


SparkSession spark = SparkSession
        .builder()
        .appName("Java Spark SQL basic example")
        .config("spark.some.config.option", "some-value")
        .getOrCreate();

Dataset<Row> titanicDFCsv = spark.read().format("csv")
        .option("sep", ",")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("data/train.csv");  //数据保存在 data 文件夹下

titanicDFCsv.createOrReplaceTempView("person"); //这里创建 spark 的 person 表,方便后面用 spark sql 分析

初始化好了,我们就来看看具体的特征吧。

2.2.1 Pclass

首先看看阶层数据.


Dataset<Row> s = spark.sql("select Pclass, Survived, count(*) from person group by Pclass, Survived order by Pclass asc, Survived asc");
System.out.println("阶层存活统计");
s.show();


ChartUtil.createStackedBarChart(s, "pclass.jpg");

其中画图部分的代码如下:


public static void createStackedBarChart(Dataset<Row> df, String path){
    final DefaultCategoryDataset dataset = new DefaultCategoryDataset( );
    for(Row row: df.collectAsList()){
        if(row.get(0) instanceof Integer) {
            dataset.addValue((Long) row.get(2), (Integer) row.get(1), (Integer)row.get(0));
        } else if(row.get(0) instanceof String){
            dataset.addValue((Long) row.get(2), (Integer) row.get(1), row.getString(0));
        }
    }

    String[] columns = df.columns();

    StandardChartTheme standardChartTheme = new StandardChartTheme("CN");
    standardChartTheme.setExtraLargeFont(new Font("隶书", Font.BOLD, 20));
    standardChartTheme.setRegularFont(new Font("宋书", Font.PLAIN, 15));
    standardChartTheme.setLargeFont(new Font("宋书", Font.PLAIN, 15));
    ChartFactory.setChartTheme(standardChartTheme);

    JFreeChart barChart = ChartFactory.createStackedBarChart(
            "幸存关联分析",
            columns[0], "人数",
            dataset,PlotOrientation.VERTICAL,
            true, true, false);

    int width = 640; /* Width of the image */
    int height = 480; /* Height of the image */
    File BarChart = new File(path);
    try {
        ChartUtilities.saveChartAsJPEG( BarChart , barChart , width , height );
    } catch (IOException e) {
        e.printStackTrace();
    }
}

得到如下结果:


阶层存活统计

+------+--------+--------+

|Pclass|Survived|count(1)|

+------+--------+--------+

| 1| 0| 80|

| 1| 1| 136|

| 2| 0| 97|

| 2| 1| 87|

| 3| 0| 372|

| 3| 1| 119|

+------+--------+--------+

很明显可以看出,第一阶层的人存活率最高。存活136个,占比 63%;第二阶层占比47%;第三阶层占比24%。

更直观的用图表看看,结果更加明显:

这里写图片描述

2.2.2 Sex

再看看性别的影响:


s = spark.sql("select Sex, Survived, count(*) from person group by Sex, Survived order by Sex asc, Survived asc");
System.out.println("性别存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "sex.jpg");

结果如下:
这里写图片描述

女性的存活率要比男性高得多。

2.2.3 Age

看看年龄的影响。由于年龄比较多,我们将其进行简单的分组。5岁以下,5-12岁,12-20岁,20-30岁,30-40岁,40-50岁,50-60岁,60岁以上分别分到一组。


String sql = "select " +
        "case " +
        "when Age <= 5 then 1 " +
        "when Age >5 and Age <= 12 then 2 " +
        "when Age > 12 and Age <= 20 then 3 " +
        "when Age > 20 and Age <= 30 then 4 " +
        "when Age > 30 and Age <= 40 then 5 " +
        "when Age > 40 and Age <= 50 then 6 " +
        "when Age > 50 and Age <= 60 then 7 " +
        "when Age > 60 then 8 " +
        "end " +
        "as ageGroup, Survived, count(*) from person group by ageGroup, Survived order by ageGroup, Survived";
s = spark.sql(sql);
System.out.println("年龄分组存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "age.jpg");

得到结果如下:

这里写图片描述

结果不是特别明显,能看出的是小孩子幸存率比较高,而老人低一些。

2.2.4 SibSp

船上的兄弟数目有影响吗?我们来看看


s = spark.sql("select SibSp, Survived, count(*) from person group by SibSp, Survived order by SibSp asc, Survived");
System.out.println("兄弟数存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "sibsp.jpg");

结果如下:
这里写图片描述

有兄弟在船上的人,幸存率币没有的稍微高一些。另外数据还可以看出,多于一个兄弟在船上的人并不多。

2.2.5 Parch

是否有父母在船上会有影响吗?


s = spark.sql("select Parch, Survived, count(*) from person group by Parch, Survived order by Parch asc, Survived");
System.out.println("长辈数存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "parch.jpg");

结果如下:
这里写图片描述

好像有父母在船上要高一些。

2.2.6 Embark

出发站会有影响吗?


s = spark.sql("select Embarked, Survived, count(*) from person group by Embarked, Survived order by Embarked, Survived");
System.out.println("出发站存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "embark.jpg");

看看结果:

这里写图片描述

貌似在 C 站上传的存活率高一些呢~~

2.2.7 Fare

票价是否有影响呢?由于票价也比较分散,我们同样进行分组处理


sql = "select " +
        "case " +
        "when Fare <= 10 then 1 " +
        "when Fare > 10 and Fare <= 20 then 2 " +
        "when Fare > 20 and Fare <= 30 then 3 " +
        "when Fare > 30 and Fare <= 60 then 4 " +
        "when Fare > 60 and Fare <= 100 then 5 " +
        "else 6 " +
        "end " +
        "as fareGroup, Survived, count(*) from person group by fareGroup, Survived order by fareGroup asc, Survived";
s = spark.sql(sql);
System.out.println("船票存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "fare.jpg");

结果如下:
这里写图片描述

这一组的结果也比较明显:票价越贵存活率越高!

2.2.8 Cabin

船舱是否有影响呢?


s = spark.sql("select LEFT(Cabin, 1) as ca, Survived, count(*) from person group by ca, Survived order by ca, Survived");
System.out.println("cabin 存活统计");
s.show();
ChartUtil.createStackedBarChart(s, "cabin.jpg");

看结果:

这里写图片描述

2.2.9 Name & Ticket

名字和票号这两个字段看不出什么特征,第一感觉是暂时丢弃不做处理。

当然,如果详细分析的话可能可以从名字提取出国籍或者性别之类的,我们暂时不作考虑。

同样,票号也可以根据是否是纯数字、以及数字长度做一个简单的切分。不过暂时也不考虑。

下面,我们将使用这些特征进行模型设计。首先看看数据预处理。

3. 数据预处理

3.1 缺失值处理

从数据可以看出,Age 和 Cabin 这两个字段有大量的缺失数据。我们分别需要对它们进行一定的处理。

  • Cabin

针对 Cabin,暂时只将其转换为表示存在的1 和表示不存在的0. 我们使用 spark 的 dataset api 直接进行转换


titanicDFCsv = titanicDFCsv.withColumn("Cabin", 
        functions.when(functions.col("Cabin").isNull(), 0)
                .otherwise(1));
titanicDFCsv.show();
  • Age

针对年龄数据,可以使用简单的均值填充方案进行处理。我们将使用 spark 中的 Imputer 进行填充操作,转换器如下:


Imputer imputer = new Imputer()
        .setInputCols(new String[]{"Age"})
        .setOutputCols(new String[]{"ageNew"});
3.2 格式转换
  • Sex

性别是用字符串 ‘male’ 和 ‘female’ 表示的,无法直接输入模型。将其转换为数值类型:


titanicDFCsv = titanicDFCsv.withColumn("Sex",
        functions.when(functions.col("Sex").equalTo("male"), 1)
                .otherwise(0));
  • Embark

登船站 Embark 是 category 类型的,因此可以转换为 one_hot 编码的格式。由于category 是用字符串表示的,因此需要先转换为数值类型,再转换为 one_hot 格式:


StringIndexer indexer = new StringIndexer()
        .setInputCol("Embarked")
        .setOutputCol("EmbarkedIndex");
indexer.setHandleInvalid("keep");

// 转换成稀疏矩阵
OneHotEncoderEstimator encoder = new OneHotEncoderEstimator()
        .setInputCols(new String[] {"EmbarkedIndex"})
        .setOutputCols(new String[] {"EmbarkedVec"});

3.3 数据归一化

Age 和 Fare 这两个字段的取值范围比较大。我们知道,若使用梯度下降算法,未归一化的数据可能造成算法收敛困难。因此需要对这两个字段进行归一化。

在归一化之前,需要先将所选字段组合成一个 Vector. 字段为 features


VectorAssembler assembler = new VectorAssembler()
        .setInputCols(new String[]{"Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Cabin", "EmbarkedVec"})
        .setOutputCol("features");

spark 提供的归一化工具支持一个 Vector 上进行归一化,相当于对 Vector 上的所有字段归一化。将 features 转换为 scaledFeatures.




StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures")
.setWithStd(true)
.setWithMean(false);

4. 逻辑回归、参数调优

4.1 基本的逻辑回归

我们首先使用逻辑回归来训练模型


LogisticRegression lr = new LogisticRegression()
        .setMaxIter(100)
        .setRegParam(0.1)
        .setFeaturesCol("scaledFeatures")
        .setLabelCol("Survived")
        .setElasticNetParam(0);

模型主要有以下参数:

  • maxIter 最大迭代次数

  • regParam 规范化项的权重

  • featuresCol 训练的特征列. 这一列本身又是个矩阵,所有训练模型需要的特征都在里面

  • labelCol 训练的预测列. 存放实际的 y 值.

  • elasticNetParam 这个是 elastic net 的参数,介于0-1之间. 表示 L1和L2正则之间的比例. 0 表示 L2 正则,1表示 L1 正则. 即下面公式中的 α 值:

αw1+(1α)12w22 α ‖ w ‖ 1 + ( 1 − α ) 1 2 ‖ w ‖ 2 2

4.2 使用 CV 进行模型选择

调参,每一组参数代表一个模型

CV (Cross-Validation)即交叉验证。

spark 提供了 CrossValidator 类来协助我们进行交叉验证,选择最合适的模型。其工作原理伪码如下:


将训练数据集划分为 k 份 [t1,t2,..., tk]

对每个模型 m:

  在 k 份数据上训练模型 k 次,模型的评价指标 E(m) 为每次训练评价指标的均值

选择评价指标最高的 E(m) 对应的模型 m

如果我们有 n 个模型,将数据切分成 k 分,则交叉验证需要训练模型 n*k 次。在实际使用中需要考虑到这一点。

下面是我们的实现:




ParamMap[] paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[] {0.3, 0.1, 0.03, 0.01})
.addGrid(lr.elasticNetParam(), new double[] {0, 0.3, 0.6, 1})
.build();

CrossValidator cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator()
.setMetricName("areaUnderROC")
.setLabelCol("Survived")
.setRawPredictionCol("prediction")) //使用二分评价指标
.setEstimatorParamMaps(paramGrid)
.setNumFolds(3)
.setParallelism(2); // Evaluate up to 2 parameter settings in parallel

// Run cross-validation, and choose the best set of parameters.
CrossValidatorModel cvModel = cv.fit(titanicDFCsv);
PipelineModel bestPipelineMode = (PipelineModel)cvModel.bestModel();
LogisticRegressionModel bestLrModel = (LogisticRegressionModel)bestPipelineMode.stages()[5];
double bestRegParam = bestLrModel.getRegParam();
double bestElasticNetParam = bestLrModel.getElasticNetParam();
System.out.println("bestRegParam : " + bestRegParam);
System.out.println("bestElasticNetParam:" + bestElasticNetParam);

//LogisticRegressionModel lrModel = (LogisticRegressionModel)model.stages()[5];
BinaryLogisticRegressionTrainingSummary summary = bestLrModel.binarySummary();
double precision = summary.weightedPrecision();
double recall = summary.weightedRecall();
double accuracy = summary.accuracy();

System.out.println("Precision: " + precision);
System.out.println("Recall: " + recall);
System.out.println("Accuracy: " + accuracy);

得到如下结果:


bestRegParam : 0.01

bestElasticNetParam:0.0

Precision: 0.7975373034403689

Recall: 0.7991021324354658

Accuracy: 0.7991021324354658

结果显示,最好的 lr 模型 regParam 为0.01,elasticNetParam 为 0,即使用 L2 正则. 模型的准确率达到 79.9。

5. 输出结果

输出结果代码如下:


Dataset<Row> testDF = spark.read().format("csv")
        .option("sep", ",")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("data/test.csv");
testDF = testDF.withColumn("Cabin",
        functions.when(functions.col("Cabin").isNull(), 0)
                .otherwise(1));
testDF = testDF.withColumn("Sex",
        functions.when(functions.col("Sex").equalTo("male"), 1)
                .otherwise(0));
testDF = testDF.withColumn("Fare",
        functions.when(functions.col("Fare").isNull(), 0)
                .otherwise(functions.col("Fare")));
testDF.show();

Dataset<Row> predictions = pipelineModel.transform(testDF);
for(String s: predictions.columns()){
    System.out.println(s);
}

CsvWriter csvWriter = new CsvWriter("data/submission.csv", ',', Charset.forName("UTF-8"));
String[] csvHeaders = { "PassengerId", "Survived"};
csvWriter.writeRecord(csvHeaders);
for (Row r : predictions.select("PassengerId", "prediction").collectAsList()) {
    System.out.println(r.get(0) + "->" + r.get(1));
    String[] csvContent = {(r.get(0)).toString(), String.valueOf(((Double)r.get(1)).intValue())};
    csvWriter.writeRecord(csvContent);
}
csvWriter.close();

我们将输出的 submission.csv 文件提交,得到如下结果:

这里写图片描述

名次比较低啊~~。不过怎么也算首次提交,只使用了最基本的 LR 模型,还算不错了。后面可以慢慢优化。

本来打算好好折腾一下 java 机器学习的。但是一遍流程走下来,后面不准备用 java spark 搞了,很多 python 里面现成的工具 java 里面都没有。后面切换 python~~~

6. 参考资料

https://blog.csdn.net/g11d111/article/details/77164074

https://www.kaggle.com/

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
根据引用的分析,在进行Kaggle Titanic生存预测时,可以将乘客的年龄、性别和船票等级作为重点数据进行分析。乘客的家属数量可能对生存率有影响,但需要与其他信息一起探讨,例如乘客是否是船上所有家属中最年轻的一个。船票编号、价格和客舱号可以视为一类信息,与船票等级可能有关联。而最后一个登船港口对生存率的影响暂时被假设为无关,因为仅凭几百条数据很难确定其影响。 引用和可以看到,登船地点对生存率有一定的影响。在S港口登船的人数最多,C港口次之,Q港口最少。在S港口登船的乘客生存率较低,而在C港口登船的乘客生存率较高。另外,在不同登船港口的女性乘客占比也不同,C港口的女性乘客数量最多,Q港口次之,S港口最少。由于前面已经了解到女性的生存率明显高于男性,因此性别因素可能是导致生存率差异的原因之一。 综上所述,在Kaggle Titanic生存预测中,可以将乘客的年龄、性别、船票等级和登船地点作为重要的特征进行分析。其中,女性的生存率较高,C港口登船的乘客生存率较高,而S港口登船的乘客生存率较低。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [kaggle|泰坦尼克号生存预测](https://blog.csdn.net/weixin_45435206/article/details/104422277)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [【机器学习kaggle赛事】泰坦尼克号生存预测](https://blog.csdn.net/m0_51933492/article/details/126895547)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值