使用决策树进行收入预测(Adult数据集)

本实验并不是用的C4.5算法,因为sklearn里面是没有真正实现C4.5算法的,本实验严格来说其实是cart算法。(本人小白一枚,纯粹为了记录完成结课作业的过程,代码丢带到jupyter里面完全可跑,如有问题记得踢我。)

一、数据集下载

UCI Machine Learning Repository

二、模型代码实现以及结果展示(使用sklearn库)

导入包

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

加载数据

def load_data(file_path):
    column_names = [
        "age", "workclass", "fnlwgt", "education", "education_num", "marital_status",
        "occupation", "relationship", "race", "sex", "capital_gain", "capital_loss",
        "hours_per_week", "native_country", "income"
    ]
    data = pd.read_csv(file_path, names=column_names, na_values=" ?", skipinitialspace=True)
    data.head()
    return data

数据预处理

def preprocess_data(data):
    data = data.dropna()#删除缺失值
    label_encoders = {}
    for column in data.select_dtypes(include=['object']).columns:
        label_encoders[column] = LabelEncoder()
        data[column] = label_encoders[column].fit_transform(data[column])
    return data

模型训练+决策树可视化

def train_evaluate_model(data):
    X = data.drop('income', axis=1)
    y = data['income']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    clf = DecisionTreeClassifier(criterion='entropy', random_state=42, max_depth=5)#初始化一个决策树分类器模型。
    clf.fit(X_train, y_train)#用训练集来训练模型
    predictions = clf.predict(X_test)#用训练好的模型对测试集进行预测
    report = classification_report(y_test, predictions)#返回模型的分类报告。 

    # Visualize the decision tree
    plt.figure(figsize=(15, 10))
    plot_tree(clf, filled=True, feature_names=list(X.columns), class_names=['<=50K', '>50K'], rounded=True, fontsize=10)
    plt.tight_layout() 
    plt.show()
    
## 在训练集和测试集上分布利用训练好的模型进行预测
    train_predict = clf.predict(X_train)
    test_predict = clf.predict(X_test)
    from sklearn import metrics
    ## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
    print('The accuracy of the tree is:',metrics.accuracy_score(y_train,train_predict))
    print('The accuracy of the tree  is:',metrics.accuracy_score(y_test,test_predict))
    ## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
    confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
    print('The confusion matrix result:\n',confusion_matrix_result)
    # 混淆矩阵结果可视化
    plt.figure(figsize=(8, 6))
    sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
   
    plt.show()
    return report

主函数调用

def main():
    file_path = 'D:/postgraduate_study/pythonProject/adult/adult.data'  # Replace with your data file path
    data = load_data(file_path)
    data = preprocess_data(data)
    report = train_evaluate_model(data)
    print(report)
main()

结果展示

图一 决策树可视化

图二 混淆矩阵

图三 模型准确率

                                                                图四 模型评估

三、数据集可视化分析

进行简单的数据查看

file_path = 'D:/postgraduate_study/pythonProject/adult/adult.data'

data = pd.read_csv(file_path)
data.info()# 利用.info()查看数据的整体信息
data.head()# 进行简单的数据查看

柱状图、散点图展示

column_names = [
        "age", "workclass", "fnlwgt", "education", "education_num", "marital_status",
        "occupation", "relationship", "race", "sex", "capital_gain", "capital_loss",
        "hours_per_week", "native_country", "income"
    ]
file_path = 'D:/postgraduate_study/pythonProject/adult/adult.data'  # Replace with your data file path
adult = pd.read_csv(file_path, names=column_names, na_values=" ?", skipinitialspace=True)
adult = adult.dropna()#删除缺失值
#从图1中可以直观看出高收入占比比较高的是执行管理(Exec-managerial)、专业教授(Prof-specialty)
#比较低的是清洁工(Handlers-cleaners)、养殖渔业(Farming-fishing)。
#高收入的职业往往是一些技术含量要高一些的工作或者是科研方面的。
sns.countplot(x= 'occupation', hue='income', data=adult)#hue:在x或y标签划分的同时,再以hue标签划分统计个数
plt.xticks(fontsize=6, rotation=-45) #调整x轴标签字体大小
plt.show()

#通过对数据和图形的分析可知,“education”是标称属性
#取值有:'HS-grad高中毕业', 'Some-college', 'Bachelors学士', 'Masters硕士', 'Assoc-voc职业学校', '11th', 'Assoc-acdm'
#'10th', '7th-8th', 'Prof-school', '9th', '12th', 'Doctorate', '5th-6th', '1st-4th', 'Preschool'
#从图2可以看出,学历越高收入>50K的占比越高,大部分人受过高等教育。
sns.countplot(x= 'education', hue='income', data=adult)
plt.xticks(fontsize=6, rotation=-45) #调整x轴标签字体大小
plt.show()
#创建散点图可视化
data_san = pd.read_csv('D:/postgraduate_study/pythonProject/adult/adult.data', names=column_names, na_values=" ?", skipinitialspace=True)
sns.pairplot(data_san, hue='income')
plt.show()

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值