![4f5fa758bf343ecc2cf2671833213355.png](https://img-blog.csdnimg.cn/img_convert/4f5fa758bf343ecc2cf2671833213355.png)
近段时间学习使用pytorch搭建神经网络,用kaggle竞赛入门题目《泰坦尼克号生存预测》进行练习,特征工程参考目前网上的教程,希望有大佬进行指导。
不足:
- 怎么查看分类变量与连续变量的相关性,如何选取特征字段
- 模型每层的参数怎么定义
数据字段:
![68b64363b3fc2f5620e40029c6d373cb.png](https://img-blog.csdnimg.cn/img_convert/68b64363b3fc2f5620e40029c6d373cb.png)
读取数据
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
train = pd.read_csv('data/Titanic/train.csv') #训练集
test = pd.read_csv('data/Titanic/test.csv') # 测试集
print('训练集:', train.shape, '测试集:', test.shape)
# 合并训练集和测试集
total_data = train.append(test, sort=False, ignore_index=True)
训练集: (891, 12) 测试集: (418, 11)
total_data.head()
![9be3fca3edfa2b284d754c72f114264f.png](https://img-blog.csdnimg.cn/img_convert/9be3fca3edfa2b284d754c72f114264f.png)
# 查看数据摘要信息
train.info()
print("-" * 40)
test.info()
![d022496cce2eaaf1281ef02d8d6b2834.png](https://img-blog.csdnimg.cn/img_convert/d022496cce2eaaf1281ef02d8d6b2834.png)
训练集的Age、Cabin、Embarked存在丢失数据;测试集的Age、Fare、Cabin存在丢失数据
特征工程
# 查看生存比例
total_data['Survived'].value_counts().plot.pie(autopct='%1.2f%%')
![d28b7b3d7bedf65fa7364a0d6ba4c135.png](https://img-blog.csdnimg.cn/img_convert/d28b7b3d7bedf65fa7364a0d6ba4c135.png)
存活率36.38%,死亡率63.62%
# 查看不同性别的人员存活率
print(total_data.groupby(['Sex'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Sex', hue='Survived', data=total_data)
plt.title('Sex and Survived')
![3f123e7f3c0603b46cc7db93ad463087.png](https://img-blog.csdnimg.cn/img_convert/3f123e7f3c0603b46cc7db93ad463087.png)
登船人数男性占比64.75%,但女性的存活机率74.2%远高于男性存活率18.8%
# 查看Embarked列值分布
total_data['Embarked'].value_counts()
![ce982635bf22dde0db2f51f2cce7da02.png](https://img-blog.csdnimg.cn/img_convert/ce982635bf22dde0db2f51f2cce7da02.png)
# 用众数填充Embarked空值
total_data['Embarked'].fillna(
total_data.Embarked.mode().values[0], inplace=True)
# 查看不同上船地人员的存活率
print(total_data.groupby(['Embarked'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Embarked', hue='Survived', data=total_data)
plt.title('Embarked and Survived')
![b0c970cb9ac57f569863296a53f27cd6.png](https://img-blog.csdnimg.cn/img_convert/b0c970cb9ac57f569863296a53f27cd6.png)
C地登船的存活率最高、其次为Q地登船、S地登船人数最多但存活率仅有1/3
# Cabin缺失比较多,用Unknown替代缺失值
total_data['Cabin'].fillna('U', inplace=True)
total_data['Cabin'] = total_data['Cabin'].map(
lambda x: re.compile('([a-zA-Z]+)').search(x).group())
print(total_data.groupby(['Cabin'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Cabin', hue='Survived', data=total_data)
plt.title('Cabin and Survived')
![0545e75b2cb1d1095b31ea2f199b5eea.png](https://img-blog.csdnimg.cn/img_convert/0545e75b2cb1d1095b31ea2f199b5eea.png)
船舱票无信息的群体占77%,存活率仅0.3;船舱票B/D/E存活率较高均超70%
# 不同票等级生存的分布
print(total_data.groupby(['Pclass'])['Survived'].agg(['count', 'mean']))
# 不同票等级生存的几率
plt.figure(figsize=(10, 5))
sns.countplot(x='Pclass', hue='Survived', data=total_data)
plt.title('Pclass and Survived')
![a779ae59dc05b59463b9c7b535b060cd.png](https://img-blog.csdnimg.cn/img_convert/a779ae59dc05b59463b9c7b535b060cd.png)
票等级越高存活率就越高;3等级的人数占比超50%,但存活率不到1/3
# 不同仓位男女生存的几率
print(total_data[['Sex', 'Pclass', 'Survived']].groupby(
['Pclass', 'Sex']).agg(['count', 'mean']))
total_data[['Sex', 'Pclass', 'Survived']].groupby(
['Pclass', 'Sex']).mean().plot.bar(figsize=(10, 5))
plt.xticks(rotation=0)
plt.title('Sex, Pclass and Survived')