Machine Learning in Practice: Classifying US Salaries Above or Below 50K


References:
Data processing: https://cloud.tencent.com/developer/article/1338337
Data analysis: https://zhuanlan.zhihu.com/p/297955188
Using multiple methods: https://blog.csdn.net/weixin_37379106/article/details/103569653

1. Importing the data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # font that can render Chinese labels
plt.rcParams['axes.unicode_minus'] = False # render the minus sign correctly
df = pd.read_csv('adults.csv')
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             32561 non-null  int64 
 1   workclass       32561 non-null  object
 2   final_weight    32561 non-null  int64 
 3   education       32561 non-null  object
 4   education_num   32561 non-null  int64 
 5   marital_status  32561 non-null  object
 6   occupation      32561 non-null  object
 7   relationship    32561 non-null  object
 8   race            32561 non-null  object
 9   sex             32561 non-null  object
 10  capital_gain    32561 non-null  int64 
 11  capital_loss    32561 non-null  int64 
 12  hours_per_week  32561 non-null  int64 
 13  native_country  32561 non-null  object
 14  salary          32561 non-null  object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB
df.head()
   age         workclass  final_weight  education  education_num      marital_status         occupation   relationship   race     sex  capital_gain  capital_loss  hours_per_week native_country salary
0   39         State-gov         77516  Bachelors             13       Never-married       Adm-clerical  Not-in-family  White    Male          2174             0              40  United-States  <=50K
1   50  Self-emp-not-inc         83311  Bachelors             13  Married-civ-spouse    Exec-managerial        Husband  White    Male             0             0              13  United-States  <=50K
2   38           Private        215646    HS-grad              9            Divorced  Handlers-cleaners  Not-in-family  White    Male             0             0              40  United-States  <=50K
3   53           Private        234721       11th              7  Married-civ-spouse  Handlers-cleaners        Husband  Black    Male             0             0              40  United-States  <=50K
4   28           Private        338409  Bachelors             13  Married-civ-spouse     Prof-specialty           Wife  Black  Female             0             0              40           Cuba  <=50K
df.workclass.value_counts()
Private             22696
Self-emp-not-inc     2541
Local-gov            2093
?                    1836
State-gov            1298
Self-emp-inc         1116
Federal-gov           960
Without-pay            14
Never-worked            7
Name: workclass, dtype: int64
df.salary.value_counts()
<=50K    24720
>50K      7841
Name: salary, dtype: int64

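2. Processing the data

2.1 Handling missing values

# Missing values are marked with '?'. Turn them into NaN, then drop the affected rows:
df = df.replace('?',np.nan)   # keeps every row, with gaps marked as NaN
clean_df = df.dropna()        # only the rows without any missing value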
clean_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 30162 entries, 0 to 32560
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             30162 non-null  int64 
 1   workclass       30162 non-null  object
 2   final_weight    30162 non-null  int64 
 3   education       30162 non-null  object
 4   education_num   30162 non-null  int64 
 5   marital_status  30162 non-null  object
 6   occupation      30162 non-null  object
 7   relationship    30162 non-null  object
 8   race            30162 non-null  object
 9   sex             30162 non-null  object
 10  capital_gain    30162 non-null  int64 
 11  capital_loss    30162 non-null  int64 
 12  hours_per_week  30162 non-null  int64 
 13  native_country  30162 non-null  object
 14  salary          30162 non-null  object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB
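
As a minimal sketch of the step above, here is the same replace-then-drop pattern on a made-up four-row frame. The values are illustrative and not taken from adults.csv. Replacing the '?' marker with NaN makes the gaps visible to pandas' missing-value tools, and dropna() then discards every row that has at least one gap.

```python
import numpy as np
import pandas as pd

# Toy frame (hypothetical values) that uses '?' as the missing-value marker
toy = pd.DataFrame({
    "workclass": ["Private", "?", "State-gov", "Private"],
    "occupation": ["Sales", "?", "Adm-clerical", "?"],
    "age": [39, 50, 38, 53],
})

toy_nan = toy.replace("?", np.nan)   # '?' becomes a real NaN
toy_clean = toy_nan.dropna()         # drop rows with any NaN

print(toy_nan.isna().sum())          # missing count per column
print(len(toy), "->", len(toy_clean), "rows")
```

Keeping both frames around, as the post does with df and clean_df, makes it possible to explore the full data while still having a gap-free copy at hand.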

2.2 Exploring the data (exploration uses df, in which '?' has been replaced with np.nan; clean_df is the dataset with the rows containing '?' removed)

import seaborn as sns
%matplotlib inline
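# Check the correlations between the numeric columns
sns.heatmap(df.select_dtypes('number').corr()
           ,annot=True
           ,center=0
           ,linewidths=0.8)
<AxesSubplot:>

[Figure: correlation heatmap of the numeric columns]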

Checking correlations directly is of limited use here, because many of the columns are categorical rather than numeric.
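
One way to still get something out of a correlation check is to restrict it to the numeric columns and add a 0/1 version of the target. The sketch below does this on a small invented frame; the rows and the helper column name salary_high are assumptions made for illustration, not part of the original notebook.

```python
import pandas as pd

# Toy rows (hypothetical values) mixing numeric and categorical columns
toy = pd.DataFrame({
    "age": [25, 38, 52, 46, 30, 60],
    "hours_per_week": [40, 45, 60, 50, 20, 40],
    "sex": ["Male", "Female", "Male", "Male", "Female", "Male"],
    "salary": ["<=50K", "<=50K", ">50K", ">50K", "<=50K", ">50K"],
})

# Keep only the numeric columns, then append the target encoded as 0/1
numeric = toy.select_dtypes(include="number").copy()
numeric["salary_high"] = (toy["salary"] == ">50K").astype(int)

corr = numeric.corr()
print(corr["salary_high"].sort_values(ascending=False))
```

The categorical columns are simply left out here; they are examined one by one in the plots that follow.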

# Age distribution
sns.set_style('whitegrid')
plt.subplots(figsize=(15,8))
s = df['age'].value_counts()
sns.barplot(x=s.index,y=s.values)
<AxesSubplot:>

[Figure: bar chart of the number of respondents at each age]

df.age.mean() # mean age
38.58164675532078

# 1. Age versus salary
s=df['age'].value_counts()
# reindex so that ages with no >50K earners still get a (zero) bar and both layers line up
k=df['age'][df['salary']=='>50K'].value_counts().reindex(s.index,fill_value=0)
sns.set_style("whitegrid")
f, ax = plt.subplots(figsize=(18, 9))
sns.set_color_codes("pastel")
sns.barplot(x=s.index,y=s.values,label='total',color="b")
sns.barplot(x=k.index,y=k.values,label='income>50K',color="g")
ax.legend(ncol=2, loc="upper left", frameon=True)
<matplotlib.legend.Legend at 0x208a75380a0>

[Figure: bar chart of the total head count per age, overlaid with the count earning >50K]

# 2. Education level
# HS-grad is the largest group with about 10,500 people, followed by Some-college
plt.subplots(figsize=(15,6))
s = df['education'].value_counts()
sns.barplot(x=s.index,y=s.values)
<AxesSubplot:>

[Figure: bar chart of the number of respondents per education level]

df['education_num'].value_counts()
9     10501
10     7291
13     5355
14     1723
11     1382
7      1175
12     1067
6       933
4       646
15      576
5       514
8       433
16      413
3       333
2       168
1        51
Name: education_num, dtype: int64
edu_n = df.groupby(['education','education_num'])['education'].count()
edu_n
education     education_num
10th          6                  933
11th          7                 1175
12th          8                  433
1st-4th       2                  168
5th-6th       3                  333
7th-8th       4                  646
9th           5                  514
Assoc-acdm    12                1067
Assoc-voc     11                1382
Bachelors     13                5355
Doctorate     16                 413
HS-grad       9                10501
Masters       14                1723
Preschool     1                   51
Prof-school   15                 576
Some-college  10                7291
Name: education, dtype: int64

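education and education_num map one-to-one, so one of the two columns can be dropped before modeling.

education is a categorical variable and education_num is a numeric one. Education levels are ordered, so comparing the numbers is meaningful; education_num is therefore the column to keep.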
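
The one-to-one claim can be checked mechanically rather than read off the grouped counts. A minimal sketch on a few invented rows (the values are illustrative only): every label should map to exactly one number, and every number to exactly one label.

```python
import pandas as pd

# Toy sample (hypothetical rows) with the same two columns as the dataset
toy = pd.DataFrame({
    "education": ["Bachelors", "HS-grad", "Bachelors", "Masters", "HS-grad"],
    "education_num": [13, 9, 13, 14, 9],
})

# Distinct numbers seen per label, and distinct labels seen per number
nums_per_label = toy.groupby("education")["education_num"].nunique()
labels_per_num = toy.groupby("education_num")["education"].nunique()

is_one_to_one = bool((nums_per_label == 1).all() and (labels_per_num == 1).all())
print(is_one_to_one)
```

Running the same two groupby calls on df would confirm the mapping for the real data before a column is deleted.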

s
HS-grad         10501
Some-college     7291
Bachelors        5355
Masters          1723
Assoc-voc        1382
11th             1175
Assoc-acdm       1067
10th              933
7th-8th           646
Prof-school       576
9th               514
12th              433
Doctorate         413
5th-6th           333
1st-4th           168
Preschool          51
Name: education, dtype: int64
edu_hsalary = df['education'][df['salary']=='>50K'].value_counts()
edu_lsalary = df['education'][df['salary']=='<=50K'].value_counts()
edu_high_percent = edu_hsalary/s
edu_low_percent = edu_lsalary/s
fig = plt.figure(figsize=(15,6))
sns.barplot(x=edu_high_percent.index,y=edu_high_percent.values,color='red',label='edu_high_salary')
sns.barplot(x=edu_low_percent.index,y=edu_low_percent.values,bottom = edu_high_percent,color='yellow',label='edu_low_salary')
fig.legend(ncol=2,loc='upper center',frameon=True)
<matplotlib.legend.Legend at 0x208a73f1ca0>

[Figure: stacked bars of the >50K and <=50K shares per education level]

The share earning more than 50K is highest for Prof-school, Doctorate, Masters, and Bachelors, in that order, so it is positively correlated with years of education.
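
The per-level share can also be computed in one step, since the mean of a boolean column is the fraction of True values. A small sketch on invented rows (not the real data), assuming the same column names as the dataset:

```python
import pandas as pd

# Toy rows (hypothetical values)
toy = pd.DataFrame({
    "education": ["Bachelors", "Bachelors", "HS-grad", "HS-grad", "HS-grad", "Masters"],
    "salary": [">50K", "<=50K", "<=50K", "<=50K", ">50K", ">50K"],
})

# Fraction of >50K earners within each education level, highest first
high_share = (
    toy["salary"].eq(">50K")
    .groupby(toy["education"])
    .mean()
    .sort_values(ascending=False)
)
print(high_share)
```

Levels with no >50K earners come out as 0 here instead of as missing values, which avoids gaps when the shares are plotted.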

# 3. Marital status
f,ax = plt.subplots(figsize=(11,6))
s=df['marital_status'].value_counts()
sns.barplot(x=s.index,y=s.values)
ax.set_xticklabels(ax.get_xticklabels(),rotation=45)
[Text(0, 0, 'Married-civ-spouse'), Text(1, 0, 'Never-married'), Text(2, 0, 'Divorced'), Text(3, 0, 'Separated'), Text(4, 0, 'Widowed'), Text(5, 0, 'Married-spouse-absent'), Text(6, 0, 'Married-AF-spouse')]

[Figure: bar chart of the number of respondents per marital status]

mer_hsalary=df['marital_status'][df['salary']=='>50K'].value_counts()
mer_lsalary=df['marital_status'][df['salary']=='<=50K'].value_counts()
f,ax = plt.subplots(figsize=[12,6])
# order= keeps both layers on the same category positions
sns.barplot(x=mer_hsalary.index,y=mer_hsalary.values,order=s.index,color='blue',alpha = 0.7,label='mer_high_salary')
sns.barplot(x=mer_lsalary.index,y=mer_lsalary.values,order=s.index,color='yellow',alpha = 0.5,label='mer_low_salary')
ax.legend(ncol=2,loc='upper center',frameon=True)
<matplotlib.legend.Legend at 0x208a7e41eb0>

[Figure: overlaid bars of the >50K and <=50K counts per marital status]

Being married correlates positively with a higher income, but this alone does not establish a causal link.

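# 4. Occupation distribution and its relation to income
plt.subplots(figsize=(8,8))
s=df['occupation'].value_counts()
sns.barplot(y=s.index,x=s.values)
<AxesSubplot:>

[Figure: horizontal bar chart of the number of respondents per occupation]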

The three largest occupations are Prof-specialty, Craft-repair, and Exec-managerial.

occ_hsalary=df['occupation'][df['salary']=='>50K'].value_counts()
occ_lsalary=df['occupation'][df['salary']=='<=50K'].value_counts()
f,ax=plt.subplots(figsize=[12,6])
a=s.index
sns.barplot(y=occ_hsalary.index,x=occ_hsalary.values,color='yellow',order=a,alpha = 0.7,label='occ_high_salary')
sns.barplot(y=occ_lsalary.index,x=occ_lsalary.values,color='blue',order=a,alpha=0.5,label='occ_low_salary')
ax.legend(ncol=2, loc="lower center", frameon=True)
<matplotlib.legend.Legend at 0x208a841b400>

[Figure: overlaid horizontal bars of the >50K and <=50K counts per occupation]

The share of high earners is comparatively high for Exec-managerial and Prof-specialty and comparatively low for Handlers-cleaners and Farming-fishing, which matches everyday intuition.

# 5. Family relationship
plt.subplots(figsize=(8,8))
s=df['relationship'].value_counts()
plt.pie(s.values,labels=s.index,autopct='%1.1f%%')

[Figure: pie chart of relationship shares: Husband 40.5%, Not-in-family 25.5%, Own-child 15.6%, Unmarried 10.6%, Wife 4.8%, Other-relative 3.0%]

# Race
plt.subplots(figsize=(8,8))
s=df['race'].value_counts()
plt.pie(s.values,labels=s.index,autopct='%1.1f%%')

[Figure: pie chart of race shares: White 85.4%, Black 9.6%, Asian-Pac-Islander 3.2%, Amer-Indian-Eskimo 1.0%, Other 0.8%]

# 6. Sex
df['sex'].value_counts()
Male      21790
Female    10771
Name: sex, dtype: int64
# 7. Capital gains and losses
plt.subplots(figsize=(7,5))
# histplot with a KDE curve replaces the deprecated distplot
sns.histplot(df['capital_loss'][df['capital_loss']!=0],kde=True,stat='density')
<AxesSubplot:xlabel='capital_loss', ylabel='Density'>

[Figure: density of the non-zero capital_loss values]

The density of the non-zero capital losses peaks at around 2,000 and is roughly bell-shaped.

plt.subplots(figsize=(7,5))
# histplot with a KDE curve replaces the deprecated distplot
sns.histplot(df['capital_gain'][df['capital_gain']!=0],kde=True,stat='density')
<AxesSubplot:xlabel='capital_gain', ylabel='Density'>

[Figure: density of the non-zero capital_gain values]

The non-zero capital gains are generally large, with many extreme values near 100,000 US dollars.

# rows with a non-zero capital gain or a non-zero capital loss
has_capital = (df['capital_gain']!=0)|(df['capital_loss']!=0)
df['capital_gain'][has_capital].agg(['mean','count'])
mean     8293.387852
count    4231.000000
Name: capital_gain, dtype: float64
df['capital_loss'][has_capital].agg(['mean','count'])
mean      671.874261
count    4231.000000
Name: capital_loss, dtype: float64

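Among the 4,231 respondents with a non-zero capital gain or loss, the average capital gain is about 8,293 and the average capital loss about 672, so gains clearly outweigh losses on average.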

gain_age=df.groupby(['age'])['capital_gain'].sum()
loss_age=df.groupby(['age'])['capital_loss'].sum()
fig =plt.figure(figsize=(15,6))
sns.barplot(x=gain_age.index,y=gain_age.values,color='yellow',alpha=0.7,label='capital_gain')
sns.barplot(x=loss_age.index,y=loss_age.values,color='red',alpha = 0.3,label='capital_loss')
fig.legend(ncol=2, loc="upper center", frameon=True)
<matplotlib.legend.Legend at 0x208a834b550>

[Figure: overlaid bars of total capital_gain and total capital_loss per age]

This chart compares capital gains and losses by age.

# 8. Working hours
plt.subplots(figsize=(9,5))
# histplot with a KDE curve replaces the deprecated distplot
sns.histplot(df['hours_per_week'],kde=True,stat='density')
<AxesSubplot:xlabel='hours_per_week', ylabel='Density'>

[Figure: density of hours_per_week]

Close to 50% of the respondents work 40 hours per week.

plt.subplots(figsize=(8, 9))
s=df['native_country'][df['native_country']!='United-States'].value_counts()
sns.barplot(y=s.index,x=s.values)
<AxesSubplot:>

[Figure: horizontal bar chart of the number of respondents per native country, excluding United-States]

The top three countries of origin among immigrants are Mexico, Philippines, and Germany.


2.3 Analyzing the missing values

import missingno as msno # library for visualizing missing values
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 15 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   age             32561 non-null  int64 
 1   workclass       30725 non-null  object
 2   final_weight    32561 non-null  int64 
 3   education       32561 non-null  object
 4   education_num   32561 non-null  int64 
 5   marital_status  32561 non-null  object
 6   occupation      30718 non-null  object
 7   relationship    32561 non-null  object
 8   race            32561 non-null  object
 9   sex             32561 non-null  object
 10  capital_gain    32561 non-null  int64 
 11  capital_loss    32561 non-null  int64 
 12  hours_per_week  32561 non-null  int64 
 13  native_country  31978 non-null  object
 14  salary          32561 non-null  object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB

Three columns have missing values: workclass, occupation, and native_country.

msno.bar(df)
<AxesSubplot:>


[Figure: bar chart of the non-null count per column]

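Next, plot a heatmap of the missing values. It shows, for each pair of features, how strongly their missingness is related, measured by the Pearson correlation coefficient.

A value of 1 between occupation and workclass means that the two columns are almost always missing in the same rows.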

msno.heatmap(df,figsize=(3,2))
<AxesSubplot:>


[Figure: heatmap of the pairwise missingness correlation between workclass, occupation, and native_country]
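
The quantity behind such a heatmap can be reproduced with plain pandas: turn each column into a 0/1 missingness indicator and correlate the indicators. The sketch below uses an invented six-row frame and is my reading of what the heatmap displays, not code taken from missingno itself.

```python
import numpy as np
import pandas as pd

# Toy frame (hypothetical values): workclass and occupation are missing together
toy = pd.DataFrame({
    "workclass":      ["Private", np.nan, "State-gov", np.nan, "Private", "Private"],
    "occupation":     ["Sales", np.nan, "Adm-clerical", np.nan, "Sales", "Sales"],
    "native_country": ["Cuba", "Mexico", np.nan, "Cuba", "Cuba", "Mexico"],
})

# 1 where a value is missing, 0 otherwise
missing = toy.isna().astype(int)

# Pearson correlation between the missingness indicators
nullity_corr = missing.corr()
print(nullity_corr.round(2))
```

Columns without any missing value have a constant indicator, so their correlation is undefined; that is presumably why only the three incomplete columns appear in the plot above.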

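Next, plot the missing-value matrix for the dataset. The more white lines a column shows, the more values it is missing.

The result shows that workclass and occupation have more missing values than native_country.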

msno.matrix(df,figsize=(6,3))
<AxesSubplot:>


[Figure: missing-value matrix of all columns]

workclass and occupation are missing in 5.64% and 5.66% of the rows respectively, and native_country in 1.79%.

temp = df.apply(lambda x:x.isna().sum()/len(x))
# temp = temp.loc[:,(temp!=0).any()]
temp
age               0.000000
workclass         0.056386
final_weight      0.000000
education         0.000000
education_num     0.000000
marital_status    0.000000
occupation        0.056601
relationship      0.000000
race              0.000000
sex               0.000000
capital_gain      0.000000
capital_loss      0.000000
hours_per_week    0.000000
native_country    0.017905
salary            0.000000
dtype: float64
temp.plot(kind='bar',figsize=(8,4))
<AxesSubplot:>

[Figure: bar chart of the missing-value ratio per column]

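2.4 Data processing and feature engineering

Some rows have a missing occupation while workclass is "Never-worked"; the reverse never occurs.

This is because people who have never worked have no occupation, so for these rows occupation can simply be filled with a new category.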

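df.loc[df['occupation'].isna() & df['workclass'].notna(), ['occupation','workclass']]
      occupation     workclass
5361         NaN  Never-worked
10845        NaN  Never-worked
14772        NaN  Never-worked
20337        NaN  Never-worked
23232        NaN  Never-worked
32304        NaN  Never-worked
32314        NaN  Never-worked
# The workclass label is spelled 'Never-worked' (with a hyphen); comparing against 'Never_worked' matches no rows.
# The outputs further down come from the original run, in which no rows were filled,
# so with this comparison expect one extra occupation dummy column.
df.loc[df['workclass']=='Never-worked','occupation'] = 'Never_worked'
df['workclass'].value_counts()
Private             22696
Self-emp-not-inc     2541
Local-gov            2093
State-gov            1298
Self-emp-inc         1116
Federal-gov           960
Without-pay            14
Never-worked            7
Name: workclass, dtype: int64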
df.sex.value_counts()
Male      21790
Female    10771
Name: sex, dtype: int64
data = df.copy(deep=True)
# Encode sex as 1 (Male) / 0 (Female)
data['sex'] = 1*(data['sex']=='Male')
data.head()
   age         workclass  final_weight  education  education_num      marital_status         occupation   relationship   race  sex  capital_gain  capital_loss  hours_per_week native_country salary
0   39         State-gov         77516  Bachelors             13       Never-married       Adm-clerical  Not-in-family  White    1          2174             0              40  United-States  <=50K
1   50  Self-emp-not-inc         83311  Bachelors             13  Married-civ-spouse    Exec-managerial        Husband  White    1             0             0              13  United-States  <=50K
2   38           Private        215646    HS-grad              9            Divorced  Handlers-cleaners  Not-in-family  White    1             0             0              40  United-States  <=50K
3   53           Private        234721       11th              7  Married-civ-spouse  Handlers-cleaners        Husband  Black    1             0             0              40  United-States  <=50K
4   28           Private        338409  Bachelors             13  Married-civ-spouse     Prof-specialty           Wife  Black    0             0             0              40           Cuba  <=50K
# Drop education (education_num carries the same information)
del data['education']
data.head()
   age         workclass  final_weight  education_num      marital_status         occupation   relationship   race  sex  capital_gain  capital_loss  hours_per_week native_country salary
0   39         State-gov         77516             13       Never-married       Adm-clerical  Not-in-family  White    1          2174             0              40  United-States  <=50K
1   50  Self-emp-not-inc         83311             13  Married-civ-spouse    Exec-managerial        Husband  White    1             0             0              13  United-States  <=50K
2   38           Private        215646              9            Divorced  Handlers-cleaners  Not-in-family  White    1             0             0              40  United-States  <=50K
3   53           Private        234721              7  Married-civ-spouse  Handlers-cleaners        Husband  Black    1             0             0              40  United-States  <=50K
4   28           Private        338409             13  Married-civ-spouse     Prof-specialty           Wife  Black    0             0             0              40           Cuba  <=50K
# Encode the salary label as 1 (>50K) / 0 (<=50K)
data['salary'] = 1*(df['salary']=='>50K')
data.salary.value_counts()
0    24720
1     7841
Name: salary, dtype: int64
data.head(5)
   age         workclass  final_weight  education_num      marital_status         occupation   relationship   race  sex  capital_gain  capital_loss  hours_per_week native_country  salary
0   39         State-gov         77516             13       Never-married       Adm-clerical  Not-in-family  White    1          2174             0              40  United-States       0
1   50  Self-emp-not-inc         83311             13  Married-civ-spouse    Exec-managerial        Husband  White    1             0             0              13  United-States       0
2   38           Private        215646              9            Divorced  Handlers-cleaners  Not-in-family  White    1             0             0              40  United-States       0
3   53           Private        234721              7  Married-civ-spouse  Handlers-cleaners        Husband  Black    1             0             0              40  United-States       0
4   28           Private        338409             13  Married-civ-spouse     Prof-specialty           Wife  Black    0             0             0              40           Cuba       0
# Drop the columns that will be one-hot encoded, so that duplicate column names cannot break the encoding
data.drop(['workclass','marital_status','occupation','relationship','race','native_country'],axis=1,inplace=True)
data.head()
   age  final_weight  education_num  sex  capital_gain  capital_loss  hours_per_week  salary
0   39         77516             13    1          2174             0              40       0
1   50         83311             13    1             0             0              13       0
2   38        215646              9    1             0             0              40       0
3   53        234721              7    1             0             0              40       0
4   28        338409             13    0             0             0              40       0
data = data.join(pd.get_dummies(df.workclass))
data = data.join(pd.get_dummies(df.marital_status))
data.columns
Index(['age', 'final_weight', 'education_num', 'sex', 'capital_gain',
       'capital_loss', 'hours_per_week', 'salary', 'Federal-gov', 'Local-gov',
       'Never-worked', 'Private', 'Self-emp-inc', 'Self-emp-not-inc',
       'State-gov', 'Without-pay', 'Divorced', 'Married-AF-spouse',
       'Married-civ-spouse', 'Married-spouse-absent', 'Never-married',
       'Separated', 'Widowed'],
      dtype='object')
data = data.join(pd.get_dummies(df.occupation))
data.head()
   age  final_weight  education_num  sex  capital_gain  capital_loss  hours_per_week  salary  Federal-gov  Local-gov  ...  Farming-fishing  Handlers-cleaners  Machine-op-inspct  Other-service  Priv-house-serv  Prof-specialty  Protective-serv  Sales  Tech-support  Transport-moving
0   39         77516             13    1          2174             0              40       0            0          0  ...                0                  0                  0              0                0               0                0      0             0                 0
1   50         83311             13    1             0             0              13       0            0          0  ...                0                  0                  0              0                0               0                0      0             0                 0
2   38        215646              9    1             0             0              40       0            0          0  ...                0                  1                  0              0                0               0                0      0             0                 0
3   53        234721              7    1             0             0              40       0            0          0  ...                0                  1                  0              0                0               0                0      0             0                 0
4   28        338409             13    0             0             0              40       0            0          0  ...                0                  0                  0              0                0               1                0      0             0                 0

5 rows × 37 columns


data = data.join(pd.get_dummies(df.relationship))
data = data.join(pd.get_dummies(df.race))
data = data.join(pd.get_dummies(df.native_country))
data.head()
   age  final_weight  education_num  sex  capital_gain  capital_loss  hours_per_week  salary  Federal-gov  Local-gov  ...  Portugal  Puerto-Rico  Scotland  South  Taiwan  Thailand  Trinadad&Tobago  United-States  Vietnam  Yugoslavia
0   39         77516             13    1          2174             0              40       0            0          0  ...         0            0         0      0       0         0                0              1        0           0
1   50         83311             13    1             0             0              13       0            0          0  ...         0            0         0      0       0         0                0              1        0           0
2   38        215646              9    1             0             0              40       0            0          0  ...         0            0         0      0       0         0                0              1        0           0
3   53        234721              7    1             0             0              40       0            0          0  ...         0            0         0      0       0         0                0              1        0           0
4   28        338409             13    0             0             0              40       0            0          0  ...         0            0         0      0       0         0                0              0        0           0

5 rows × 89 columns
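
The join-based encoding above relies on category names being unique across all encoded columns. As an alternative sketch (not the approach used in this post), pd.get_dummies can encode several columns in one call and prefix each new column with its source column, which rules out name clashes. The three rows below are invented for illustration.

```python
import pandas as pd

# Toy rows (hypothetical values) where two columns share a category name
toy = pd.DataFrame({
    "age": [39, 50, 38],
    "workclass": ["State-gov", "Never-worked", "Private"],
    "occupation": ["Adm-clerical", "Never-worked", "Sales"],
})

# Encode both columns at once; each new column is named <column>_<category>
encoded = pd.get_dummies(toy, columns=["workclass", "occupation"], dtype=int)
print(encoded.columns.tolist())
```

The original columns are dropped automatically, so no separate drop step is needed with this variant.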

3. Training

3.1 Splitting the dataset

data_ = data.copy(deep=True)
y = data_.salary
del data_['salary']

X = data_.iloc[:,:]
X.head()
   age  final_weight  education_num  sex  capital_gain  capital_loss  hours_per_week  Federal-gov  Local-gov  Never-worked  ...  Portugal  Puerto-Rico  Scotland  South  Taiwan  Thailand  Trinadad&Tobago  United-States  Vietnam  Yugoslavia
0   39         77516             13    1          2174             0              40            0          0             0  ...         0            0         0      0       0         0                0              1        0           0
1   50         83311             13    1             0             0              13            0          0             0  ...         0            0         0      0       0         0                0              1        0           0
2   38        215646              9    1             0             0              40            0          0             0  ...         0            0         0      0       0         0                0              1        0           0
3   53        234721              7    1             0             0              40            0          0             0  ...         0            0         0      0       0         0                0              1        0           0
4   28        338409             13    0             0             0              40            0          0             0  ...         0            0         0      0       0         0                0              0        0           0

5 rows × 88 columns

from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(X,y)
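
The split above uses the defaults, so it is neither reproducible nor guaranteed to preserve the roughly 3:1 class ratio. A minimal sketch of a seeded, stratified split on synthetic data follows; the array names, the 25% positive rate, and the seed are assumptions made for illustration, and the results reported below were not produced this way.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for X and y: 200 rows, 25% positives
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 3))
y_toy = np.array([0] * 150 + [1] * 50)

# stratify keeps the class ratio in both parts; random_state fixes the shuffle
X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.25, stratify=y_toy, random_state=42
)
print(y_tr.mean(), y_te.mean())
```

With a fixed seed, rerunning the notebook yields the same train and test rows, which makes the classification reports comparable between runs.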

1. Random forest

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
# Random forest hyperparameter search; it took several hours to finish
# RF = RandomForestClassifier()
# param_grid = {'max_features':['auto','sqrt','log2'],
#               'max_depth':np.arange(1,15),
#               'min_samples_leaf':np.arange(1,20),
#               'min_samples_split':[0.1,0.3,0.5],
#               'criterion':['gini','entropy']
#              }
# GS = GridSearchCV(RF,param_grid=param_grid,cv=5,scoring='roc_auc').fit(x_train,y_train)
# print(GS.best_params_)
# print(GS.best_score_)
# print(GS.best_estimator_)
"""
{'criterion': 'gini', 'max_depth': 12, 'max_features': 'auto', 'min_samples_leaf': 1, 'min_samples_split': 0.1}
0.9044572707834764
RandomForestClassifier(max_depth=12, min_samples_split=0.1)
"""
# For classifiers 'auto' meant 'sqrt'; 'auto' was deprecated in scikit-learn 1.1 and later removed
RF = RandomForestClassifier(criterion='gini', max_depth = 12, max_features='sqrt', min_samples_leaf = 3, min_samples_split = 0.1).fit(x_train,y_train)
RF.score(x_train,y_train)
0.8416871416871416
from sklearn.metrics import classification_report
print(classification_report(y_test,RF.predict(x_test)))
              precision    recall  f1-score   support

           0       0.85      0.97      0.91      6170
           1       0.82      0.47      0.60      1971

    accuracy                           0.85      8141
   macro avg       0.84      0.72      0.75      8141
weighted avg       0.84      0.85      0.83      8141

2. SVM (support vector machine)

from sklearn import svm
svc_bal = svm.SVC(class_weight='balanced')  # reweight the classes to compensate for the imbalance
svc_bal.fit(x_train,y_train)
SVC(class_weight='balanced')
print(classification_report(y_test,svc_bal.predict(x_test)))
              precision    recall  f1-score   support

           0       0.79      0.99      0.88      6170
           1       0.86      0.19      0.31      1971

    accuracy                           0.80      8141
   macro avg       0.83      0.59      0.60      8141
weighted avg       0.81      0.80      0.74      8141
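
The recall of 0.19 for class 1 is low. One possible reason, which is an assumption here and was not tested in the original run, is that the features sit on very different scales (final_weight is in the hundreds of thousands while the dummies are 0/1), and an RBF-kernel SVC is sensitive to that. A sketch of the usual remedy, standardizing inside a pipeline, on synthetic data rather than the salary data:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic, imbalanced binary problem (NOT the salary data)
X_syn, y_syn = make_classification(
    n_samples=600, n_features=6, n_informative=4,
    weights=[0.75, 0.25], random_state=0,
)
X_syn[:, 0] *= 100_000  # blow up one feature's scale, similar in spirit to final_weight

X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, stratify=y_syn, random_state=0
)

# The scaler is fitted on the training part only, then applied before the SVC
scaled_svc = make_pipeline(StandardScaler(), SVC(class_weight="balanced"))
scaled_svc.fit(X_tr, y_tr)
test_accuracy = scaled_svc.score(X_te, y_te)
print(test_accuracy)
```

Whether scaling actually lifts the recall on this dataset would have to be checked by rerunning the report with such a pipeline.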

3. Logistic regression

from sklearn.linear_model import LogisticRegression,LogisticRegressionCV
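
The original post imports the logistic regression classes but does not fit a model. As a hedged sketch of what such a step could look like, the example below fits LogisticRegression in a scaling pipeline on synthetic data; the data, the max_iter value, and the variable names are assumptions, and no result for the salary data is implied.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic, imbalanced binary problem (NOT the salary data)
X_syn, y_syn = make_classification(
    n_samples=1000, n_features=8, n_informative=5,
    weights=[0.75, 0.25], random_state=0,
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, stratify=y_syn, random_state=0
)

# Scaling helps the solver converge; max_iter is raised as a precaution
logreg = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
logreg.fit(X_tr, y_tr)

report = classification_report(y_te, logreg.predict(X_te))
print(report)
```

On the real data the same pipeline would be fitted with x_train and y_train and evaluated with the classification_report call used for the other models.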

4. AdaBoost

from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
# The base learner is passed positionally because its keyword name differs across
# scikit-learn versions (base_estimator in older releases, estimator in newer ones)
ada = AdaBoostClassifier(
    DecisionTreeClassifier(max_depth=2),
    n_estimators=50,
    learning_rate=1.0
).fit(x_train,y_train)
print(classification_report(y_test,ada.predict(x_test)))
              precision    recall  f1-score   support

           0       0.89      0.94      0.92      6170
           1       0.78      0.64      0.70      1971

    accuracy                           0.87      8141
   macro avg       0.83      0.79      0.81      8141
weighted avg       0.86      0.87      0.86      8141

ada.score(x_test,y_test)
0.8691806903328829

5. GBDT

from sklearn.ensemble import GradientBoostingClassifier
gbdt = GradientBoostingClassifier().fit(x_train,y_train)
print(classification_report(y_test,gbdt.predict(x_test)))
              precision    recall  f1-score   support

           0       0.89      0.95      0.92      6170
           1       0.80      0.61      0.69      1971

    accuracy                           0.87      8141
   macro avg       0.84      0.78      0.81      8141
weighted avg       0.86      0.87      0.86      8141

6. XGBoost

from xgboost import XGBClassifier
from xgboost import XGBRegressor
sk_xgb_c = XGBClassifier().fit(x_train,y_train)
C:\Anaconda\lib\site-packages\xgboost\sklearn.py:1146: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].  warnings.warn(label_encoder_deprecation_msg, UserWarning)


[17:23:12] WARNING: C:/Users/Administrator/workspace/xgboost-win64_release_1.4.0/src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
print(classification_report(y_test,sk_xgb_c.predict(x_test)))
              precision    recall  f1-score   support

           0       0.90      0.94      0.92      6170
           1       0.78      0.67      0.72      1971

    accuracy                           0.87      8141
   macro avg       0.84      0.80      0.82      8141
weighted avg       0.87      0.87      0.87      8141
