本次作业是运用决策树算法,来进行微博分类。
这是第一次写机器学习项目,所以代码真的很幼稚,但也是真的学到了很多代码在这里:代码
代码整体结构:
整个代码分成三个文件,
一个决策树相关的函数类(clas.py)
一个用于训练的文件(train.py)
一个用于测试的文件(test.py)
train.py :这是训练算法的主文件,包括读取文件、构造特征值、数据降维、数据保存等功能
其中:数据降维采用基于特征选择的降维,通过分析每个词的文档频率,只保留那些出现的文档频率最高的那些词。换一个角度,我移除只出现在最多x 个文档中的所有词。最后经过测试,我删除了出现次数小于20的特征值。这使得我的算法正确率有所提高,重点是提升了程序的运行速度。
clas.py :包含一个决策树的类,含有创建决策树,计算信息熵,选取最佳特征值,预测等功能。其中,决策树运用的特征选取的规则:信息增益,此处借鉴了网上“笨鸟先飞”《机器学习实战之决策树》和 “超级杰哥”《 python 计算信息熵和信息增益》的算法思路
test.py :测试集,包含读取数据