Shark源码分析(八):CART算法
决策树算法是机器学习中非常常用的一种算法。在我关于机器学习的博客中有对决策树算法进行详细的介绍。在Shark中,只实现了CART这一种类型的决策树,它可以用于分类任务或是回归任务中。在这里我们只对其中有关分类任务的部分代码进行分析。
CARTClassifier类
这个类用于定义决策树,该类定义在文件<include/shark/Models/Trees/CARTClassifier.h>
中。
template<class LabelType>
class CARTClassifier : public AbstractModel<RealVector,LabelType>
{
private:
typedef AbstractModel<RealVector, LabelType> base_type;
public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
//定义决策树的结点类
struct NodeInfo {
std::size_t nodeId; //结点的标号
std::size_t attributeIndex; //在该结点划分属性的编号
double attributeValue; //划分属性的值
std::size_t leftNodeId; //左儿子结点的编号
std::size_t rightNodeId; //右儿子结点的编号
LabelType label; //数据在该结点的类标签,之后可以看到,这其实是一个向量,表明类的隶属度
double misclassProp; //假设该结点为叶结点,以该结点中数据最多的类别为类标签的分类错误率
//之后的这两个值用于剪枝过程
std::size_t r;
double g;
template<class Archive>
void serialize(Archive & ar, const unsigned int version){
ar & nodeId;
ar & attributeIndex;
ar & attributeValue;
ar & leftNodeId;
ar & rightNodeId;
ar & label;
ar & misclassProp;
ar & r;
ar & g;
}
};
//用数组的形式来存储决策树,这可能会让你想到最大堆或是最小堆的组织方式
//但之后你会发现,并不是这样的,因为一棵决策树不一定是一棵平衡二叉树
typedef std::vector<NodeInfo> TreeType;
CARTClassifier()
{}
CARTClassifier(TreeType const& tree)
{
m_tree=tree;
}
//这里的optimize指的是能以常数级的访问时间访问到树中的结点
CARTClassifier(TreeType const& tree, bool optimize)
{
if (optimize)
setTree(tree);
else
m_tree=tree;
}
CARTClassifier(TreeType const& tree, std::size_t d)
{
setTree(tree);
m_inputDimension = d;
}
std::string name() const
{ return "CARTClassifier"; }
boost::shared_ptr<State> createState() const{
return boost::shared_ptr<State>(new EmptyState());
}
using base_type::eval;
//根据输入数据,计算它们的类别
void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
std::size_t numPatterns = shark::size(patterns);
LabelType const& firstResult = evalPattern(row(patterns,0));
outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
get(outputs,0) = firstResult;
for(std::size_t i = 0; i != numPatterns; ++i){
get(outputs,i) = evalPattern(row(patterns,i));
}
}
void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
eval(patterns,outputs);
}
void eval(RealVector const& pattern, LabelType& output){
output = evalPattern(pattern);
}
void setTree(TreeType const& tree){
m_tree = tree;
optimizeTree(m_tree);
}
TreeType getTree() const {
return m_tree;
}
std::size_t numberOfParameters() const{
return 0;
}
RealVector parameterVector() const {
return RealVector();
}
void setParameterVector(RealVector const& param) {
SHARK_ASSERT(param.size() == 0);
}
void read(InArchive& archive){
archive >> m_tree;
}
void write(OutArchive& archive) const {
archive << m_tree;
}
//计算每一个属性被用作划分属性的次数,因为数据的每一维属性都是连续属性,所以可以被用作划分属性多次
UIntVector countAttributes() const {
SHARK_ASSERT(m_inputDimension > 0);
UIntVector r(m_inputDimension, 0);
typename TreeType::const_iterator it;
for(it = m_tree.begin(); it != m_tree.end(); ++it) {
//std::cout << "NodeId: " <<it->leftNodeId << std::endl;
if(it->leftNodeId != 0) { // not a label
r(it->attributeIndex)++;
}
}
return r;
}
std::size_t inputSize() const {
return m_inputDimension;
}
void setInputDimension(std::size_t d) {
m_inputDimension = d;
}
//根据输入数据,计算模型的分类误差
//针对回归任务和分类任务有不同的重载版本,目标函数不一样
void computeOOBerror(const ClassificationDataset& dataOOB){
ZeroOneLoss<unsigned int, RealVector> lossOOB;
// 这里输出的类标签是一个向量,表明对每个类的隶属度
Dat