Shark源码分析(二):模型与训练方法
之前两篇博客都说的是整个架构底层的东西,如何来存放输入的数据,以便于在计算时进行加速。而这一篇呢,会向上走一点,来看一看如何将模型进行抽象。
在『统计机器学习』中有提到这样一句话,『统计学习方法是由模型、策略和算法构成的』。Shark正是按照这种划分来构建整个库的架构。在监督学习过程中,模型就是所要学习的条件概率分布或是决策函数。Shark是将几个比较类似的算法,例如LinearRegression, LassoRegression,将其中的共同点抽象出来,例如权值向量和偏置,形成一个共同的模型,LinearModel。
有了模型的假设空间,统计学习接着需要考虑的是按照什么样的准则学习或选择最优模型。统计学习的目标在于从假设空间中选取最优的模型。这实际上就是目标函数的选择。在Shark中也将目标函数单独出来成一个基类为AbstractObjectiveFunction的类系,方便用户自由选择模型所使用的目标函数。
算法是指学习模型的具体计算方法。统计学习基于训练数据集,根据学习策略,从假设空间中选择最优模型,最后考虑用什么样的计算方法求解最优模型。
在此基础之上,将所有模型和训练方法的共同点再抽象出来,形成最高层的基类。接下来,我们就来看一下这两个基类。
AbstractModel类
这是所有模型的基类。这个类的定义在<include/shark/Models/AbstractModel.h>
文件中。
template<class InputTypeT, class OutputTypeT>
class AbstractModel : public IParameterizable, public INameable, public ISerializable
{
public:
typedef InputTypeT InputType;
typedef OutputTypeT OutputType;
typedef OutputType result_type;
typedef typename Batch<InputType>::type BatchInputType;
typedef typename Batch<OutputType>::type BatchOutputType;
AbstractModel() { }
virtual ~AbstractModel() { }
//这个枚举里存储的信息包括相关参数、input是否存在一阶、二阶导数
//这些信息应该是用在那些利用到梯度的训练方法中
//可能就会有人有疑问,训练方法不是已经与模型分开了吗
//但是模型的目标函数是定义在该类中的,稍后会有介绍
enum Feature {
HAS_FIRST_PARAMETER_DERIVATIVE = 1,
HAS_SECOND_PARAMETER_DERIVATIVE = 2,
HAS_FIRST_INPUT_DERIVATIVE = 4,
HAS_SECOND_INPUT_DERIVATIVE = 8,
IS_SEQUENTIAL = 16