机器学习作业7 - 随机森林

随机森林

首先吐槽下,这个星期作业实在是太多了……一周两个机器学习项目实在伤不起啊!!!所以这一次的随机森林我决定放点水,决策树部分就不自己写了,还是调库吧……当然随机森林部分还是得自己写的。

事实上,如果决策树部分直接调库的话,随机森林可能是实现起来最简单的机器学习算法了。大体思想非常简单:

  • 从训练数据中随机抽取出一部分,用来训练一颗决策树

  • 将训练好的决策树加入到列表中

  • 预测时,使用列表中的决策树对数据进行预测,然后进行投票,得票数过半的分类为1,否则分类为0

那么,直接开始写吧……首先训练一堆决策树,并且放入列表:

    def fit(self, data, label, tree_count=100):
        data_size = len(label)
        for idx in range(tree_count):
            # 抽取训练数据的1/4为样本,使用SkLearn中的决策树进行训练
            sample_idx = np.random.randint(0, data_size,
                         data_size // 4, np.int)
            sample_data, sample_label = data[sample_idx],
                         label[sample_idx]
            tree = DecisionTreeRegressor()
            tree.fit(sample_data, sample_label)
            self.__trees.append(tree)

然后在预测的时候进行投票,因为分类结果要么是0,要么是1,所以偷个懒,直接把结果加在一起,如果最后某个数据的值大于决策树数量的一半,那么很明显它获得1的次数超过一半,代码如下:

    def predict(self, data):
        result = np.zeros([len(data)], np.int)
        for tree in self.__trees:
            # 投票操作,将票数累计
            result += tree.predict(data).astype(np.int)
        # 票数过半的为正样本,否则为负样本
        pos = np.where(result >= len(self.__trees) // 2)
        neg = np.where(result < len(self.__trees) // 2)
        result[pos], result[neg] = 1, 0
        return result

最后在Main里面读入数据,进行测试。本次使用Wine.data数据集,这个数据集相对Iris数据集比较复杂,在之前的对率回归分类器中准确率不到80%。Main的代码如下:

if __name__ == '__main__':
    file = open('Data/wine.data')
    data_str = file.readlines()
    np.random.shuffle(data_str)
    file.close()
    # 读取数据并进行预处理
    value = np.ndarray([len(data_str), 13], np.float32)
    label = np.ndarray([len(data_str)], np.int32)
    for outer_idx in range(len(data_str)):
        data = data_str[outer_idx].strip('\n').split(',')
        value[outer_idx] = data[1:]
        label[outer_idx] = data[0]
    # 由于数据较为简单,生成数量为30的小森林
    # 使用100条数据进行训练,30条数据进行验证
    forest = RandomForest.RandomForest()
    forest.fit(value[:100], label[:100], 100)
    result = forest.predict(value[100:])
    err_count = len(np.where(label[100:] != result)[0])
    print('共测试 30 条样本,其中错误 %d 条,准确率 %.2f%%'
          % (err_count, 100.0 - err_count * 100 / 30), end='\r')

不得不说样本数还是有点少啊……没办法,直接测试吧。
程序运行10次,结果如下:

共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 1 条,准确率 96.67%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 0 条,准确率 100.00%
共测试 30 条样本,其中错误 1 条,准确率 96.67%

可见准确率还是相当之高的,接近100%了。要不是作业太多时间不够,我十分愿意用它来挑战更有难度的数据集~源码可以点击这里查看。

那么这次就匆匆撒花了,滚去写作业啦~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值