Shark源码分析(五):线性回归算法与Lasso回归

Shark源码分析(五):线性回归算法与Lasso回归

为什么上一篇还是三,这一篇就跳到五了呢?其实我们原来提到过:

=++

这里的模型与算法我们之前都已经提到过了,虽然只是介绍了一个基类,并没有涉及到其具体的实现。在这里我们就会揭开其真正面目了。『策略』我们还没有介绍过,其实就是目标函数,在前面一些较为简单的算法中并没有涉及到这块。为了整个逻辑的完整性,我还是打算将其放在前面来介绍。

这里我们所介绍的算法是线性回归算法。它是机器学习算法中非常基本的一个算法。这里就不对它进行过多的介绍了,之后应该会写一个博客来叙述。

首先给出一个示例代码,使得有一个整体的映像。

#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/Algorithms/Trainers/LinearRegression.h>

#include <iostream>

using namespace shark;
using namespace std;

int main(int argc, char **argv) {
     if(argc < 3) {
         cerr << "usage: " << argv[0] << " (file with inputs/independent variables) (file with outputs/dependent variables)" << endl;
         exit(EXIT_FAILURE);
     }
     Data<RealVector> inputs;
     Data<RealVector> labels;
     try {
         importCSV(inputs, argv[1], ' ');
     } 
     catch (...) {
         cerr << "unable to read input data from file " <<  argv[1] << endl;
         exit(EXIT_FAILURE);
     }

     try {
         importCSV(labels, argv[2]);
     }
     catch (...) {
         cerr << "unable to read labels from file " <<  argv[2] << endl;
         exit(EXIT_FAILURE);
     }

     RegressionDataset data(inputs, labels);

     // trainer and model
     LinearRegression trainer;
     LinearModel<> model;

     // train model
     trainer.train(model, data);

     // show model parameters
     cout << "intercept: " << model.offset() << endl;
     cout << "matrix: " << model.matrix() << endl;

     SquaredLoss<> loss;
     Data<RealVector> prediction = model(data.inputs()); 
     cout << "squared loss: " << loss(data.labels(), prediction) << endl;
}

首先读取算法所需要的数据集,这里是存储在LabeledData所特化的RegressionDataset中。之后就是初始化算法所对应的模型类,以及算法的训练方法类。利用训练方法类对训练数据进行训练,将训练所得的参数写回到对应的模型中。这里的prediction就是对于数据的预测值。模型重载了括号运算符,里面包含的内容是eval函数,就是计算其输出值。最后利用了平方损失函数来衡量模型的性能。

LinearModel类

Shark中将线性回归算法归于线性模型这一大类中。线性模型是使用线性函数 f(x)=Ax+b 来进行预测的。存在两个特殊的情况是:一是输出可能只是一个单独的数;二是,偏移b可能会被省略。

该文件位于<include/shark/Models/LinearModel.h>中。

template <class InputType = RealVector>
class LinearModel : public AbstractModel<InputType,RealVector>
{
private:
    typedef AbstractModel<InputType,RealVector> base_type;
    typedef LinearModel<InputType> self_type;

    RealMatrix m_matrix; // 权值矩阵
    RealVector m_offset; // 偏置向量
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::BatchOutputType BatchOutputType;

    Lin
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值