![4f5fa758bf343ecc2cf2671833213355.png](https://i-blog.csdnimg.cn/blog_migrate/e98452085edbaf5d37972399249268b0.jpeg)
近段时间学习使用pytorch搭建神经网络,用kaggle竞赛入门题目《泰坦尼克号生存预测》进行练习,特征工程参考目前网上的教程,希望有大佬进行指导。
不足:
- 怎么查看分类变量与连续变量的相关性,如何选取特征字段
- 模型每层的参数怎么定义
数据字段:
![68b64363b3fc2f5620e40029c6d373cb.png](https://i-blog.csdnimg.cn/blog_migrate/607c96fdd354c5c037d90c6897fb9100.jpeg)
读取数据
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
train = pd.read_csv('data/Titanic/train.csv') #训练集
test = pd.read_csv('data/Titanic/test.csv') # 测试集
print('训练集:', train.shape, '测试集:', test.shape)
# 合并训练集和测试集
total_data = train.append(test, sort=False, ignore_index=True)
训练集: (891, 12) 测试集: (418, 11)
total_data.head()
![9be3fca3edfa2b284d754c72f114264f.png](https://i-blog.csdnimg.cn/blog_migrate/5978462168addd4c60cb5ba0f7287c49.jpeg)
# 查看数据摘要信息
train.info()
print("-" * 40)
test.info()
![d022496cce2eaaf1281ef02d8d6b2834.png](https://i-blog.csdnimg.cn/blog_migrate/d75c9cd6dcc64c5a967b8b2812dc6cdb.jpeg)
训练集的Age、Cabin、Embarked存在丢失数据;测试集的Age、Fare、Cabin存在丢失数据
特征工程
# 查看生存比例
total_data['Survived'].value_counts().plot.pie(autopct='%1.2f%%')
![d28b7b3d7bedf65fa7364a0d6ba4c135.png](https://i-blog.csdnimg.cn/blog_migrate/d2f87b84f75533d484b534b5299ced75.png)
存活率36.38%,死亡率63.62%
# 查看不同性别的人员存活率
print(total_data.groupby(['Sex'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Sex', hue='Survived', data=total_data)
plt.title('Sex and Survived')
![3f123e7f3c0603b46cc7db93ad463087.png](https://i-blog.csdnimg.cn/blog_migrate/4cbd3c8352d2d1c3b3936395b0f1fe07.jpeg)
登船人数男性占比64.75%,但女性的存活机率74.2%远高于男性存活率18.8%
# 查看Embarked列值分布
total_data['Embarked'].value_counts()
![ce982635bf22dde0db2f51f2cce7da02.png](https://i-blog.csdnimg.cn/blog_migrate/5656701b9b5c578bf8e3443e171567ac.png)
# 用众数填充Embarked空值
total_data['Embarked'].fillna(
total_data.Embarked.mode().values[0], inplace=True)
# 查看不同上船地人员的存活率
print(total_data.groupby(['Embarked'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Embarked', hue='Survived', data=total_data)
plt.title('Embarked and Survived')
![b0c970cb9ac57f569863296a53f27cd6.png](https://i-blog.csdnimg.cn/blog_migrate/0b012173a685c07ce7f0064a1a728e49.jpeg)
C地登船的存活率最高、其次为Q地登船、S地登船人数最多但存活率仅有1/3
# Cabin缺失比较多,用Unknown替代缺失值
total_data['Cabin'].fillna('U', inplace=True)
total_data['Cabin'] = total_data['Cabin'].map(
lambda x: re.compile('([a-zA-Z]+)').search(x).group())
print(total_data.groupby(['Cabin'])['Survived'].agg(['count', 'mean']))
plt.figure(figsize=(10, 5))
sns.countplot(x='Cabin', hue='Survived', data=total_data)
plt.title('Cabin and Survived')
![0545e75b2cb1d1095b31ea2f199b5eea.png](https://i-blog.csdnimg.cn/blog_migrate/2ad996d5073bc687d4bfb12a450c0d6f.jpeg)
船舱票无信息的群体占77%,存活率仅0.3;船舱票B/D/E存活率较高均超70%
# 不同票等级生存的分布
print(total_data.groupby(['Pclass'])['Survived'].agg(['count', 'mean']))
# 不同票等级生存的几率
plt.figure(figsize=(10, 5))
sns.countplot(x='Pclass', hue='Survived', data=total_data)
plt.title('Pclass and Survived')
![a779ae59dc05b59463b9c7b535b060cd.png](https://i-blog.csdnimg.cn/blog_migrate/27bb67e285b96ba79313b65fbc284820.jpeg)
票等级越高存活率就越高;3等级的人数占比超50%,但存活率不到1/3
# 不同仓位男女生存的几率
print(total_data[['Sex', 'Pclass', 'Survived']].groupby(
['Pclass', 'Sex']).agg(['count', 'mean']))
total_data[['Sex', 'Pclass', 'Survived']].groupby(
['Pclass', 'Sex']).mean().plot.bar(figsize=(10, 5))
plt.xticks(rotation=0)
plt.title('Sex, Pclass and Survived')