Shark源码分析(十二):线性SVM

Shark源码分析(十二):线性SVM

关于svm算法,这个在我关于机器学习的博客中已经描述的比较详实了,这里就不再赘述。svm主要有三种类型,这里我所介绍的是线性svm算法的代码。相较于使用核函数的svm算法,代码的整体框架应该是一样的,只是在对偶问题的求解上所使用的方法可能是不一样的。

LinearClassifier类

这个类所表示的是算法的决策平面,是一个多分类的线性分类模型。定义在<include/shark/Models/LinearClassifier.h>中。

template<class VectorType = RealVector>
class LinearClassifier : public ArgMaxConverter<LinearModel<VectorType> >
{
public:
    LinearClassifier(){}

    std::string name() const
    { return "LinearClassifier"; }
};

相当简单的一个类,并没有什么好说明的地方。

ArgMaxConverter类

该类是LinearClassifier的基类,其作用是将一个输出的向量通过arg_max操作转变为一个类标记,就是输出分量最大的那一维。该类定义在<include/shark/Models/Converter.h>

template<class Model>
class ArgMaxConverter : public AbstractModel<typename Model::InputType, unsigned int>
{
private:
    typedef typename Model::BatchOutputType ModelBatchOutputType;
public:
    typedef typename Model::InputType InputType;
    typedef unsigned int OutputType;
    typedef typename Batch<InputType>::type BatchInputType;
    typedef Batch<unsigned int>::type BatchOutputType;

    ArgMaxConverter()
    { }
    ArgMaxConverter(Model const& decisionFunction)
    : m_decisionFunction(decisionFunction)
    { }

    std::string name() const
    { return "ArgMaxConverter<"+m_decisionFunction.name()+">"; }

    RealVector parameterVector() const{
        return m_decisionFunction.parameterVector();
    }

    void setParameterVector(RealVector const& newParameters){
        m_decisionFunction.setParameterVector(newParameters);
    }

    std::size_t numberOfParameters() const{
        return m_decisionFunction.numberOfParameters();
    }

    Model const& decisionFunction()const{
        return m_decisionFunction;
    }

    Model& decisionFunction(){
        return m_decisionFunction;
    }

    // 计算输入数据的类标签
    void eval(BatchInputType const& input, BatchOutputType& output)const{

        ModelBatchOutputType modelResult;
        m_decisionFunction.eval(input,modelResult);
        std::size_t batchSize = shark::size(modelResult);
        output.resize(batchSize);
        if(modelResult.size2()== 1) //对于二分类的情况
        {
            for(std::size_t i = 0; i != batchSize; ++i){// 如果输出大于0表示正类,否则为负类
                output(i) = modelResult(i,0) > 0.0;
            }
        }
        else{
            for(std::size_t i = 0; i != batchSize; ++i){
                output(i) = static_cast<unsigned int>(arg_max(row(modelResult,i)));
            }
        }
    }

    void eval(BatchInputType const& input, BatchOutputType& output, State& state)const{
        eval(input,output);
    }

    void eval(InputType const & pattern, OutputType& output)const{
        typename Model::OutputType modelResult;
        m_decisionFunction.eval(pattern,modelResult);
        if(modelResult.size()== 1){
            output = modelResult(0) > 0.0;
        }
        else{
            output = static_cast<unsigned int>(arg_max(modelResult));
        }
    }

    void read(InArchive& archive){
        archive >> m_decisionFunction;
    }

    void write(OutArchive& archive) const{
        archive << m_decisionFunction;
    }

private:
    Model m_decisionFunction;
};

在LinearClassifier类的代码中,该模板类的模板参数是LinearModel,这个模板类我们之前已经介绍过了。

AbstractLinearSvmTrainer类

这个类是所有线性svm训练方法的基类。该类定义在<include/shark/Algorithms/Trainers/AbstractSvmTrainer.h>中。

template <class InputType>
class AbstractLinearSvmTrainer
: public AbstractTrainer<LinearClassifier<InputType>, unsigned int>
, public QpConfig
, public IParameterizable
{
public:
    typedef AbstractTrainer<LinearClassifier<InputType>, unsigned int> base_type;
    typedef LinearClassifier<InputType> ModelType;

    AbstractLinearSvmTrainer(double C, bool unconstrained = false)
    : m_C(C)
    , m_unconstrained(unconstrained)
    { RANGE_CHECK( C > 0 ); }

    double C() const
    { return m_C; }

    void setC(double C) {
        RANGE_CHECK( C > 0 );
        m_C = C;
    }

    bool isUnconstrained() const
    { return m_unconstrained; }

    RealVector parameterVector() const
    {
        RealVector ret(1);
        ret(0) = (m_unconstrained ? std::log(m_C) : m_C);
        return ret;
    }

    void setParameterVector(RealVector const& newParameters)
    {
        SHARK_ASSERT(newParameters.size() == 1);
        setC(m_unconstrained ? std::exp(newParameters(0)) : newParameters(0));
    }

    size_t numberOfParameters() const
    { return 1; }

    // 对于以下的两个类成员,在QpConfig的构造函数中没有为它们赋值。稍后可以看到,它们自己的构造函数中是有默认值的
    using QpConfig::m_stoppingcondition; // 算法训练的停止条件
    using QpConfig::m_solutionproperties; // 当前解的一些性质
    using QpConfig::m_verbosity; // 冗长程度(字面翻译,在后面的代码中体现的不是这个意思),默认值是0,

protected:
    double m_C; // 目标函数中正则化项的系数
    bool m_unconstrained; // 是否使用log C 代替了C,如果是的话则摆脱了C > 0的限制,并不知道这个有什么用
};

QpStoppingCondition类、QpStopType类和QpSolutionProperties类

这三个类都定义在<include/shark/Algorithms/QP/QuadraticProgram.h>中。

struct QpStoppingCondition
{
    QpStoppingCondition(double accuracy = 0.001, unsigned long long iterations = 0xffffffff, double value = 1e100, double seconds = 1e100)
    {
        minAccuracy = accuracy;
        maxIterations = iterations;
        targetValue = value;
        maxSeconds = seconds;
    }

    //违反KKT条件的阈值下限
    double minAccuracy;

    //最大迭代次数
    unsigned long long maxIterations;

    //目标函数值的阈值
    double targetValue;

    //算法运行的最长时间
    double maxSeconds;
};
enum QpStopType
{
    QpNone = 0,
    QpAccuracyReached = 1,
    QpMaxIterationsReached = 4,
    QpTimeout = 8,
};
struct QpSolutionProperties
{
    QpSolutionProperties()
    {
        type = QpNone;
        accuracy = 1e100;
        iterations = 0;
        value = 1e100;
        seconds = 0.0;
    }

    QpStopType type;
    double accuracy;     //当前解违反KKT条件的程度
    unsigned long long iterations; //当前循环的次数
    double value;    // 当前目标函数的值
    double seconds; // 当前程序的运行时间
};

LinearCSvmTrainer类

该类就是用于训练线性SVM的,定义在<include/shark/Algorithms/Trainers/CSvmTrainer.h>

template <class InputType>
class LinearCSvmTrainer : public AbstractLinearSvmTrainer<InputType>
{
public:
    typedef AbstractLinearSvmTrainer<InputType> base_type;

    LinearCSvmTrainer(double C, bool unconstrained = false) 
    : AbstractLinearSvmTrainer<InputType>(C, unconstrained){}

    std::string name() const
    { return "LinearCSvmTrainer"; }

    void train(LinearClassifier<InputType>& model, LabeledData<InputType, unsigned int> const& dataset)
    {
        std::size_t dim = inputDimension(dataset);
        QpBoxLinear<InputType> solver(dataset, dim);
        RealMatrix w(1, dim, 0.0);
        row(w, 0) = solver.solve(
                base_type::C(),
                0.0,
                QpConfig::stoppingCondition(),
                &QpConfig::solutionProperties(),
                QpConfig::verbosity() > 0);
        model.decisionFunction().setStructure(w);
    }
};

从代码中可以看出,主要还是调用QpBoxLinear类的solve方法来求解。

QpBoxLinear类

该类是利用矩阵分解的方法来求解目标函数是hinge损失函数的线性svm,定义在<include/shark/Algorithms/QP/QpBoxLinear.h>。光看代码你可能会不清楚一些操作的具体含义,需要看一下”A Dual Coordinate Descent Method for Large-scale Linear SVM”这篇论文。

template <class InputT>
class QpBoxLinear
{
public:
    typedef LabeledData<InputT, unsigned int> DatasetType;
    typedef typename LabeledData<InputT, unsigned int>::const_element_reference ElementType;

    QpBoxLinear(const DatasetType& dataset, std::size_t dim)
    : m_data(dataset)
    , m_xSquared(m_data.size())
    , m_dim(dim)
    {
        SHARK_ASSERT(dim > 0);

        for (std::size_t i=0; i<m_data.size(); i++)
        {
            ElementType x_i = m_data[i];
            m_xSquared(i) = inner_prod(x_i.input, x_i.input);
        }
    }

    // 参数reg相当于论文中的D_{ii}
    RealVector solve(
            double bound,
            double reg,
            QpStoppingCondition& stop,
            QpSolutionProperties* prop = NULL,
            bool verbose = false)
    {
        SHARK_ASSERT(bound > 0.0);
        SHARK_ASSERT(reg >= 0.0);

        Timer timer;
        timer.start();


        std::size_t ell = m_data.size();
        RealVector alpha(ell, 0.0); // 表示拉格朗日乘子
        RealVector w(m_dim, 0.0); // 权值向量
        RealVector pref(ell, 1.0);          // measure of success of individual steps
        double prefsum = ell;               // normalization constant
        std::vector<std::size_t> schedule(ell); // 更新每一个拉格朗日乘子的顺序

        // prepare counters
        std::size_t epoch = 0;
        std::size_t steps = 0;

        // prepare performance monitoring for self-adaptation
        double max_violation = 0.0;
        const double gain_learning_rate = 1.0 / ell;
        double average_gain = 0.0;
        bool canstop = true;

        // outer optimization loop
        while (true)
        {
            // 计算更新的下标顺序,至于这种算法的原理我就不是很清楚了,论文里也没有说明
            double psum = prefsum;
            prefsum = 0.0;
            std::size_t pos = 0;
            for (std::size_t i=0; i<ell; i++)
            {
                double p = pref[i];
                double num = (psum < 1e-6) ? ell - pos : std::min((double)(ell - pos), (ell - pos) * p / psum);
                std::size_t n = (std::size_t)std::floor(num);
                double prob = num - n;
                if (Rng::uni() < prob) n++;
                for (std::size_t j=0; j<n; j++)
                {
                    schedule[pos] = i;
                    pos++;
                }
                psum -= p;
                prefsum += p;
            }
            SHARK_ASSERT(pos == ell);
            for (std::size_t i=0; i<ell; i++) std::swap(schedule[i], schedule[Rng::discrete(0, ell - 1)]);

            // inner loop
            // 算法的符号与论文中是相反的,包括g,pg和new_a的计算
            max_violation = 0.0;
            for (std::size_t j=0; j<ell; j++)
            {
                // active variable
                std::size_t i = schedule[j];
                ElementType e_i = m_data[i];
                double y_i = (e_i.label > 0) ? +1.0 : -1.0;

                // compute gradient and projected gradient
                double a = alpha(i);
                double wyx = y_i * inner_prod(w, e_i.input);
                double g = 1.0 - wyx - reg * a;
                double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == bound && g > 0.0 ? 0.0 : g);

                // update maximal KKT violation over the epoch
                max_violation = std::max(max_violation, std::abs(pg));
                double gain = 0.0;

                // 更新参数的过程
                if (pg != 0.0)
                {
                    // SMO-style coordinate descent step
                    double q = m_xSquared(i) + reg;
                    double mu = g / q; // 该参数同时也表示了a的两次更新之间的差值
                    double new_a = a + mu;

                    // numerically stable update
                    if (new_a <= 0.0)
                    {
                        mu = -a;
                        new_a = 0.0;
                    }
                    else if (new_a >= bound)
                    {
                        mu = bound - a;
                        new_a = bound;
                    }

                    // 更新参数
                    alpha(i) = new_a;
                    w += (mu * y_i) * e_i.input;
                    gain = mu * (g - 0.5 * q * mu);

                    steps++;
                }

                // update gain-based preferences
                {
                    if (epoch == 0) average_gain += gain / (double)ell;
                    else
                    {
                        double change = CHANGE_RATE * (gain / average_gain - 1.0);
                        double newpref = std::min(PREF_MAX, std::max(PREF_MIN, pref(i) * std::exp(change)));
                        prefsum += newpref - pref(i);
                        pref[i] = newpref;
                        average_gain = (1.0 - gain_learning_rate) * average_gain + gain_learning_rate * gain;
                    }
                }
            }

            epoch++;

            if (stop.maxIterations > 0 && ell * epoch >= stop.maxIterations) //这里的最大循环次数指的是内部循环的次数
            {
                if (prop != NULL) prop->type = QpMaxIterationsReached;
                break;
            }

            if (timer.stop() >= stop.maxSeconds)
            {
                if (prop != NULL) prop->type = QpTimeout;
                break;
            }

            if (max_violation < stop.minAccuracy)
            {
                if (verbose) std::cout << "#" << std::flush;
                if (canstop)
                {
                    if (prop != NULL) prop->type = QpAccuracyReached;
                    break;
                }
                else
                {
                    // prepare full sweep for a reliable checking of the stopping criterion
                    canstop = true;
                    for (std::size_t i=0; i<ell; i++) pref[i] = 1.0;
                    prefsum = ell;
                }
            }
            else
            {
                if (verbose) std::cout << "." << std::flush;
                canstop = false;
            }
        }

        timer.stop();

        // compute solution statistics
        std::size_t free_SV = 0; // 不在决策边界上的支持向量个数
        std::size_t bounded_SV = 0; // 在决策边界上的支持向量的个数
        double objective = -0.5 * shark::blas::inner_prod(w, w); //计算最终的目标函数值,但计算的值有些诡异,既不是原问题的目标值也不是对偶问题的目标值
        for (std::size_t i=0; i<ell; i++)
        {
            double a = alpha(i);
            if (a > 0.0)
            {
                objective += a;
                objective -= reg/2.0 * a * a;
                if (a == bound) bounded_SV++;
                else free_SV++;
            }
        }

        // return solution statistics
        if (prop != NULL)
        {
            prop->accuracy = max_violation;       // this is approximate, but a good guess
            prop->iterations = ell * epoch;
            prop->value = objective;
            prop->seconds = timer.lastLap();
        }

        // output solution statistics
        // 这里verbose只是一个是否需要输出信息的标志
        if (verbose)
        {
            std::cout << std::endl;
            std::cout << "training time (seconds): " << timer.lastLap() << std::endl;
            std::cout << "number of epochs: " << epoch << std::endl;
            std::cout << "number of iterations: " << (ell * epoch) << std::endl;
            std::cout << "number of non-zero steps: " << steps << std::endl;
            std::cout << "dual accuracy: " << max_violation << std::endl;
            std::cout << "dual objective value: " << objective << std::endl;
            std::cout << "number of free support vectors: " << free_SV << std::endl;
            std::cout << "number of bounded support vectors: " << bounded_SV << std::endl;
        }

        // return the solution
        return w;
    }

protected:
    DataView<const DatasetType> m_data; // 训练数据            
    RealVector m_xSquared; //m_data^T m_data                      
    std::size_t m_dim; // 输入数据的维度              
};

由于线性svm只是针对于二分类问题(当然所有的svm都是这样),如果要对多分类问题建立分类器,则需要使用LinearMcSvmOVATrainer类来训练。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值