Shark源码分析(五):线性回归算法与Lasso回归
为什么上一篇还是三,这一篇就跳到五了呢?其实我们原来提到过:
方法=模型+策略+算法
这里的模型与算法我们之前都已经提到过了,虽然只是介绍了一个基类,并没有涉及到其具体的实现。在这里我们就会揭开其真正面目了。『策略』我们还没有介绍过,其实就是目标函数,在前面一些较为简单的算法中并没有涉及到这块。为了整个逻辑的完整性,我还是打算将其放在前面来介绍。
这里我们所介绍的算法是线性回归算法。它是机器学习算法中非常基本的一个算法。这里就不对它进行过多的介绍了,之后应该会写一个博客来叙述。
首先给出一个示例代码,使得有一个整体的映像。
#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/Algorithms/Trainers/LinearRegression.h>
#include <iostream>
using namespace shark;
using namespace std;
int main(int argc, char **argv) {
if(argc < 3) {
cerr << "usage: " << argv[0] << " (file with inputs/independent variables) (file with outputs/dependent variables)" << endl;
exit(EXIT_FAILURE);
}
Data<RealVector> inputs;
Data<RealVector> labels;
try {
importCSV(inputs, argv[1], ' ');
}
catch (...) {
cerr << "unable to read input data from file " << argv[1] << endl;
exit(EXIT_FAILURE);
}
try {
importCSV(labels, argv[2]);
}
catch (...) {
cerr << "unable to read labels from file " << argv[2] << endl;
exit(EXIT_FAILURE);
}
RegressionDataset data(inputs, labels);
// trainer and model
LinearRegression trainer;
LinearModel<> model;
// train model
trainer.train(model, data);
// show model parameters
cout << "intercept: " << model.offset() << endl;
cout << "matrix: " << model.matrix() << endl;
SquaredLoss<> loss;
Data<RealVector> prediction = model(data.inputs());
cout << "squared loss: " << loss(data.labels(), prediction) << endl;
}
首先读取算法所需要的数据集,这里是存储在LabeledData所特化的RegressionDataset中。之后就是初始化算法所对应的模型类,以及算法的训练方法类。利用训练方法类对训练数据进行训练,将训练所得的参数写回到对应的模型中。这里的prediction就是对于数据的预测值。模型重载了括号运算符,里面包含的内容是eval函数,就是计算其输出值。最后利用了平方损失函数来衡量模型的性能。
LinearModel类
Shark中将线性回归算法归于线性模型这一大类中。线性模型是使用线性函数 f(x)=Ax+b 来进行预测的。存在两个特殊的情况是:一是输出可能只是一个单独的数;二是,偏移b可能会被省略。
该文件位于<include/shark/Models/LinearModel.h>
中。
template <class InputType = RealVector>
class LinearModel : public AbstractModel<InputType,RealVector>
{
private:
typedef AbstractModel<InputType,RealVector> base_type;
typedef LinearModel<InputType> self_type;
RealMatrix m_matrix; // 权值矩阵
RealVector m_offset; // 偏置向量
public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
Lin