随机森林
首先吐槽下,这个星期作业实在是太多了……一周两个机器学习项目实在伤不起啊!!!所以这一次的随机森林我决定放点水,决策树部分就不自己写了,还是调库吧……当然随机森林部分还是得自己写的。
事实上,如果决策树部分直接调库的话,随机森林可能是实现起来最简单的机器学习算法了。大体思想非常简单:
从训练数据中随机抽取出一部分,用来训练一颗决策树
将训练好的决策树加入到列表中
预测时,使用列表中的决策树对数据进行预测,然后进行投票,得票数过半的分类为1,否则分类为0
那么,直接开始写吧……首先训练一堆决策树,并且放入列表:
def fit(self, data, label, tree_count=100):
data_size = len(label)
for idx in range(tree_count):
# 抽取训练数据的1/4为样本,使用SkLearn中的决策树进行训练
sample_idx = np.random.randint(0, data_size,
data_size // 4, np.int)
sample_data, sample_label = data[sample_idx],
label[sample_idx]
tree = DecisionTreeRegressor()
tree.fit(sample_data, sample_label)
self.__trees.append(tree)
然后在预测的时候进行投票,因为分类结果要么是0,要么是1,所以偷个懒,直接把结果加在一起,如果最后某个数据的值大于决策树数量的一半,那么很明显它获得1的次数超过一半,代码如下:
def predict(self, data):
result = np.zeros([len(data)], np.int)
for tree in self.__trees:
# 投票操作,将票数累计
result += tree.predict(data).astype(np.int)
# 票数过半的为正样本,否则为负样本
pos = np.where(result >= len(self.__trees) // 2)
neg = np.where(result < len(self.__trees) // 2)
result[pos], result[neg] = 1, 0
return result
最后在Main里面读入数据,进行测试。本次使用Wine.data数据集,这个数据集相对Iris数据集比较复杂,在之前的对率回归分类器中准确率不到80%。Main的代码如下:
if __name__ == '__main__':
file = open('Data/wine.data')
data_str = file.readlines()
np.random.shuffle(data_str)
file.close()
# 读取数据并进行预处理
value = np.ndarray([len(data_str), 13], np.float32)
label = np.ndarray([len(data_str)], np.int32)
for outer_idx in range(len(data_str)):
data = data_str[outer_idx].strip('\n').split(',')
value[outer_idx] = data[1:]
label[outer_idx] = data[0]
# 由于数据较为简单,生成数量为30的小森林
# 使用100条数据进行训练,30条数据进行验证
forest = RandomForest.RandomForest()
forest.fit(value[:100], label[:100], 100)
result = forest.predict(value[100:])
err_count = len(np.where(label[100:] != result)[0])
print('共测试 30 条样本,其中错误 %d 条,准确率 %.2f%%'
% (err_count, 100.0 - err_count * 100 / 30), end='\r')
不得不说样本数还是有点少啊……没办法,直接测试吧。
程序运行10次,结果如下:
共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 1 条,准确率 96.67%
可见准确率还是相当之高的,接近100%了。要不是作业太多时间不够,我十分愿意用它来挑战更有难度的数据集~源码可以点击这里查看。
那么这次就匆匆撒花了,滚去写作业啦~