GridSearch(源码)
主要参数设置:
m_MinX m_MaxX m_StepX m_LabelX m_X_Basem_MinY m_MaxYm_StepY m_LabelY m_Y_Base
classifier evaluation等等
主要函数:
public void buildClassifier(Instances data){...}
|
protected PointDouble findBest(){...}
|
protected PointDouble determineBestInGrid(Grid grid, Instances inst, int cv){...}
|
public EvaluationTask(GridSearch owner, SetupGenerator generator,Instances inst, PointDouble values, int folds, int eval) {...}
buildClassifier-> findBest()->determineBestInGrid->EvaluationTask
重要的程序:
findBest :
result = determineBestInGrid(m_Grid, sample, 2);
确定了m_MinX等参数后可以生成一个网格grid,对网格中每一对点对(x,y)先用2折交叉验证选出其中performance最好的那一对点对,即result。
判断result在网格中的位置center是否在边界上,如果在边界上并且可以扩展的话,得到新的center,然后再以新的center为中心的邻域组成的grid进行10折交叉验证,找出
其中更最优的点对result。如果新的result和旧的相同,那么退出,否则继续上述过程,具体程序如下:
findBest :
finished=false;
if (!finished) {
do {
iteration++;
resultOld = (PointDouble) result.clone();
center = m_Grid.getLocation(result); //获得在grid中的位置
if (m_Grid.isOnBorder(center)) {
log("Center is on border of grid.");
if (getGridIsExtendable()) {
if (m_GridExtensionsPerformed == getMaxGridExtensions()) {
log("Maximum number of extensions reached!\n");
finished = true;
} else {
m_GridExtensionsPerformed++;
m_Grid = m_Grid.extend(result); //扩展
center = m_Grid.getLocation(result);
log("Extending grid (" + m_GridExtensionsPerformed + "/"
+ getMaxGridExtensions() + "):\n" + m_Grid + "\n");
}
} else {
finished = true;
}
}
if (!finished) {
neighborGrid = m_Grid.subgrid((int) center.getY() + 1,
(int) center.getX() - 1, (int) center.getY() - 1,
(int) center.getX() + 1);
result = determineBestInGrid(neighborGrid, sample, 10);
log("\nResult of Step 2/Iteration " + (iteration) + ":\n" + result);
finished = m_UniformPerformance;
if (result.equals(resultOld)) {
finished = true;
log("\nNo better point found.");
}
}
} while (!finished);
}
determineBestInGrid:
Collections.sort(m_Performances, new PerformanceComparator(m_Evaluation)); //排序
result = m_Performances.get(m_Performances.size() - 1).getValues(); //选择最大值所对应的点对
EvaluationTask:
x = m_Generator.evaluate(m_Values.getX(), true); //计算x的值
y = m_Generator.evaluate(m_Values.getY(), false); //计算y的值
classifier = (Classifier) m_Generator.setup(m_Classifier, x, y);
eval = new Evaluation(data);
eval.crossValidateModel(classifier, data, m_Folds, new Random(m_Owner.getSeed())); //交叉验证
performance = new Performance(m_Values, eval);
m_Owner.addPerformance(performance, m_Folds); //
整个程序最主要就是这几个函数。