Shark源码分析(四):目标函数及其优化
我们之前反反复复地强调过,Shark的设计策略是:方法 = 模型 + 策略 + 算法。这里的策略指的就是目标函数。我们的目标是从假设空间中选取最优的模型,需要考虑的是按照什么样的准则学习或选择最优的模型。选择一个好的目标函数形式,也是算法最终取得好的效果的一个重要组成部分。
在确定了目标函数之后,我们还需要考虑的就是如何对其进行优化。在Shark中,将定义目标函数与对目标函数优化区分开来,分别定义了两个基类构成了两套继承体系。
AbstractObjectiveFunction类
该类是所有优化以及学习问题的目标函数的基类。对于那些需要有目标函数的算法都可以使用该类。因为一个目标函数不可能满足所有的需求,所以该类针对不同的优化方法提供了丰富的接口。该类定义在<include/Shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
。
template <typename PointType, typename ResultT>
class AbstractObjectiveFunction : public INameable{
public:
typedef PointType SearchPointType;
typedef ResultT ResultType;
typedef SearchPointType FirstOrderDerivative;
struct SecondOrderDerivative {
RealVector gradient;
RealMatrix hessian;
};
/// \brief List of features that are supported by an implementation.
enum Feature {
HAS_VALUE = 1,
HAS_FIRST_DERIVATIVE = 2, //是否能求一阶导数
HAS_SECOND_DERIVATIVE = 4, //是否能求二阶导数
CAN_PROPOSE_STARTING_POINT = 8, //函数能否提供搜索的起点
IS_CONSTRAINED_FEATURE = 16, //是否是一个带约束的优化问题
HAS_CONSTRAINT_HANDLER = 32, //能否提供可行域的信息
CAN_PROVIDE_CLOSEST_FEASIBLE = 64,
IS_THREAD_SAFE = 128 //能否在多线程环境下使用
};
SHARK_FEATURE_INTERFACE;
//以下函数都是用于返回目标函数特征信息的
bool hasValue()const{
return m_features & HAS_VALUE;
}
bool hasFirstDerivative()const{
return m_features & HAS_FIRST_DERIVATIVE;
}
bool hasSecondDerivative()const{
return m_features & HAS_SECOND_DERIVATIVE;
}
bool canProposeStartingPoint()const{
return m_features & CAN_PROPOSE_STARTING_POINT;
}
bool isConstrained()const{
return m_features & IS_CONSTRAINED_FEATURE;
}
bool hasConstraintHandler()const{
return m_features & HAS_CONSTRAINT_HANDLER;
}
bool canProvideClosestFeasible()const{
return m_features & CAN_PROVIDE_CLOSEST_FEASIBLE;
}
bool isThreadSafe()const{
return m_features & IS_THREAD_SAFE;
}
AbstractObjectiveFunction():m_evaluationCounter(0) {
m_features |=HAS_VALUE;
}
virtual ~AbstractObjectiveFunction() {}
virtual void init() {
m_evaluationCounter=0;
}
virtual std::size_t numberOfVariables() const=0;
virtual bool hasScalableDimensionality()const{
return false;
}
/// \brief Adjusts the number of variables if the function is scalable.
virtual void setNumberOfVariables( std::size_t numberOfVariables ){
throw SHARKEXCEPTION("dimensionality of function is not scalable");
}
virtual std::size_t numberOfObjectives() const{
return 1;
}
virtual bool hasScalableObjectives()const{
return false;
}
/// \brief Adjusts the number of objectives if the function is scalable.
virtual void setNumberOfObjectives( std::size_t numberOfObjectives ){
throw SHARKEXCEPTION("dimensionality of function is not scaleable");
}
std::size_t evaluationCounter() const {
return m_evaluationCounter;
}
AbstractConstraintHandler<SearchPointType> const& getConstraintHandler()const{
if(m_constraintHandler == NULL)
throw SHARKEXCEPTION("Objective Function does not have an constraint handler!");
return *m_constraintHandler;
}
//判断一个搜索空间中的点是否是可行的
virtual bool isFeasible( const SearchPointType & input) const {
if(hasConstraintHandler()) return getConstraintHandler().isFeasible(input);
if(isConstrained())
throw SHARKEXCEPTION("[AbstractObjectiveFunction::isFasible] not overwritten, even though function is constrained");