Shark源码分析(八):CART算法

Shark源码分析(八):CART算法

决策树算法是机器学习中非常常用的一种算法。在我关于机器学习的博客中有对决策树算法进行详细的介绍。在Shark中,只实现了CART这一种类型的决策树,它可以用于分类任务或是回归任务中。在这里我们只对其中有关分类任务的部分代码进行分析。

CARTClassifier类

这个类用于定义决策树,该类定义在文件<include/shark/Models/Trees/CARTClassifier.h>中。

template<class LabelType>
class CARTClassifier : public AbstractModel<RealVector,LabelType>
{
private:
    typedef AbstractModel<RealVector, LabelType> base_type;
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::BatchOutputType BatchOutputType;

    //定义决策树的结点类
    struct NodeInfo {
        std::size_t nodeId; //结点的标号
        std::size_t attributeIndex; //在该结点划分属性的编号
        double attributeValue; //划分属性的值
        std::size_t leftNodeId; //左儿子结点的编号
        std::size_t rightNodeId; //右儿子结点的编号
        LabelType label; //数据在该结点的类标签,之后可以看到,这其实是一个向量,表明类的隶属度
        double misclassProp; //假设该结点为叶结点,以该结点中数据最多的类别为类标签的分类错误率

        //之后的这两个值用于剪枝过程
        std::size_t r;
        double g;

       template<class Archive>
       void serialize(Archive & ar, const unsigned int version){
            ar & nodeId;
            ar & attributeIndex;
            ar & attributeValue;
            ar & leftNodeId;
            ar & rightNodeId;
            ar & label;
            ar & misclassProp;
            ar & r;
            ar & g;
        }
    };

    //用数组的形式来存储决策树,这可能会让你想到最大堆或是最小堆的组织方式
    //但之后你会发现,并不是这样的,因为一棵决策树不一定是一棵平衡二叉树
    typedef std::vector<NodeInfo> TreeType;

    CARTClassifier()
    {}

    CARTClassifier(TreeType const& tree)
    {
        m_tree=tree;
    }

    //这里的optimize指的是能以常数级的访问时间访问到树中的结点
    CARTClassifier(TreeType const& tree, bool optimize)
    {
        if (optimize)
            setTree(tree);
        else
            m_tree=tree;
    }

    CARTClassifier(TreeType const& tree, std::size_t d)
    {
        setTree(tree);
        m_inputDimension = d;
    }

    std::string name() const
    { return "CARTClassifier"; }

    boost::shared_ptr<State> createState() const{
        return boost::shared_ptr<State>(new EmptyState());
    }

    using base_type::eval;
    //根据输入数据,计算它们的类别
    void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
        std::size_t numPatterns = shark::size(patterns);

        LabelType const& firstResult = evalPattern(row(patterns,0));
        outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
        get(outputs,0) = firstResult;

        for(std::size_t i = 0; i != numPatterns; ++i){
            get(outputs,i) = evalPattern(row(patterns,i));
        }
    }

    void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
        eval(patterns,outputs);
    }

    void eval(RealVector const& pattern, LabelType& output){
        output = evalPattern(pattern);  
    }

    void setTree(TreeType const& tree){
        m_tree = tree;
        optimizeTree(m_tree);
    }

    TreeType getTree() const {
        return m_tree;
    }

    std::size_t numberOfParameters() const{
        return 0;
    }

    RealVector parameterVector() const {
        return RealVector();
    }

    void setParameterVector(RealVector const& param) {
        SHARK_ASSERT(param.size() == 0);
    }

    void read(InArchive& archive){
        archive >> m_tree;
    }

    void write(OutArchive& archive) const {
        archive << m_tree;
    }

    //计算每一个属性被用作划分属性的次数,因为数据的每一维属性都是连续属性,所以可以被用作划分属性多次
    UIntVector countAttributes() const {
        SHARK_ASSERT(m_inputDimension > 0);
        UIntVector r(m_inputDimension, 0);
        typename TreeType::const_iterator it;
        for(it = m_tree.begin(); it != m_tree.end(); ++it) {
            //std::cout << "NodeId: " <<it->leftNodeId << std::endl;
            if(it->leftNodeId != 0) { // not a label 
                r(it->attributeIndex)++;
            }
        }
        return r;
    }

    std::size_t inputSize() const {
        return m_inputDimension;
    }

    void setInputDimension(std::size_t d) {
        m_inputDimension = d;
    }

    //根据输入数据,计算模型的分类误差
    //针对回归任务和分类任务有不同的重载版本,目标函数不一样
    void computeOOBerror(const ClassificationDataset& dataOOB){
        ZeroOneLoss<unsigned int, RealVector> lossOOB;

        // 这里输出的类标签是一个向量,表明对每个类的隶属度
        Dat
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值