Ranklib部分源码分析(LambdaMART+RandomForest)
声明
本文是Ranklib部分源码的分析,参考了RankLib源码分析——guoguo881218的专栏以及Learning to Rank——wowarsenal,在此对原博主表示感谢
关于Ranklib
在The Lemur Project可以下载到Ranklib程序,Ranklib2.1和Ranklib2.3有源码可以下载,Ranklib2.4和Ranklib2.5只有jar文件可以下载。通过jad反编译后可以看到源码,整体结构差别不大。本文以Ranklib2.3为标准进行说明。
主框架(Evlauator.java)
程序主入口main函数
Ranklib程序主入口为ciir.umass.edu.eval.Evaluator
类中main函数.
其中public static void main(String[] args)
函数接收命令行传入参数。
- 首先初始化一些变量并根据传入参数给变量赋值:
for(int i=0;i<args.length;i++)
{
if(args[i].compareTo("-train")==0)
trainFile = args[++i]; //训练集
else if(args[i].compareTo("-ranker")==0)
rankerType = Integer.parseInt(args[++i]); //Rank类型
...
else if(args[i].compareTo("-metric2t")==0)
trainMetric = args[++i]; //训练集Metric
else if(args[i].compareTo("-metric2T")==0)
testMetric = args[++i]; //测试集Metric
...
else if(args[i].compareTo("-validate")==0)
validationFile = args[++i]; //验证集
else if(args[i].compareTo("-test")==0)
{
testFile = args[++i];
testFiles.add(testFile);
} //测试集
...
else if(args[i].compareTo("-save")==0)
Evaluator.modelFile = args[++i]; //模型保存位置
...
else if(args[i].compareTo("-load")==0)
{
savedModelFile = args[++i];
savedModelFiles.add(args[i]);
} //导入模型
...
else if(args[i].compareTo("-rank")==0)
rankFile = args[++i]; //待排序数据
... ... ...
//MART / LambdaMART / Random forest
else if(args[i].compareTo("-tree")==0)
{
LambdaMART.nTrees = Integer.parseInt(args[++i]);
RFRanker.nTrees = Integer.parseInt(args[i]);
} //树的棵树
else if(args[i].compareTo("-leaf")==0)
{
LambdaMART.nTreeLeaves = Integer.parseInt(args[++i]);
RFRanker.nTreeLeaves = Integer.parseInt(args[i]);
} //每棵树叶子结点数
else if(args[i].compareTo("-shrinkage")==0)
{
LambdaMART.learningRate = Float.parseFloat(args[++i]);
RFRanker.learningRate = Float.parseFloat(args[i]);
} //收缩系数
...
//Random forest
else if(args[i].compareTo("-bag")==0)
RFRanker.nBag = Integer.parseInt(args[++i]); //bags数目
- 根据参数变量进行训练
if(nThread == -1)
nThread = Runtime.getRuntime().availableProcessors();
MyThreadPool.init(nThread); //线程池初始化
...
Evaluator e = new Evaluator(rType2[rankerType], trainMetric, testMetric); //根据Rank类型以及训练集、测试集上的评价函数生成Evaluator对象
... ...
RankerFactory rf = new RankerFactory();
rf.createRanker(rType2[rankerType]).printParameters();//根据参数创建Rank对象
...
e.evaluate() //多个实现,针对不同情况进行evaluate
...