这篇文章介绍的是LMT的核心 LogisticBase & LMTNode 。
LogisticBase 是基类,可以说大部分的LMT核心功能都在这里实现。我们需要一个一个分析。
看核心函数之前,先看几个核心数组。
double[][] trainYs = getYs(m_train);
double[][] trainFs = getFs(m_numericData);
double[][] probs = getProbs(trainFs);
这三个二维数组都是(实例个数)*(类属性个数) 维的。
我们举例说明,比如现在有两个实例 a,b,类属性有0,1两个值,其中a的类属性值为1,b的类属性值为0,那么:
trainYs存储的是每一个实例真是分布情况,即 [0,1],[1,0]
trainFs存储的是按照线性规划计算出来的值,初始化为1/J,(J为类属性值的个数),这里初始化就是[0.5,0.5],[0.5,0.5]
probs 存储的是e(f[i,j]-max(trainFs[i])/sum ,翻译一下:每一个实例将按照线性规划计算出来的每个属性值减去最大的值,然后作为e的指数,最后归一化。这里能看出,属性值最大的会归一为1。
接下来看一个核心函数 performIteration(...);
函数分为三个过程。
1.对所有的属性,计算z-value,w-value,更新w,学习SimpleLinearRegression.如果无法学习到合适的SimplerLinearRegression,直接返回false。
2.如果所有的属性都学习了,则根据SimpleLinearRegression更新trainFs.
3.根据trainFs更新probs。
第一个过程:
for (int j = 0; j < m_numClasses; j++) {
// Keep track of sum of weights
double[] weights = new double[trainNumeric.numInstances()];
double weightSum = 0.0;
//make copy of data (need to save the weights)
Instances boostData = new Instances(trainNumeric);
for (int i = 0; i < trainNumeric.numInstances(); i++) {
//compute response and weight
/**
* 针对每个实例,计算z-value.计算w-value
*/
double p = probs[i][j];
double actual = trainYs[i][j];
double z = getZ(actual, p);
double w = (actual - p) / z;
/**
* 更新weight
*/
//set values for instance
Instance current = boostData.instance(i);
current.setValue(boostData.classIndex(), z);
current.setWeight(current.weight() * w);
weights[i] = current.weight();
weightSum += current.weight();
}
Instances instancesCopy = new Instances(boostData);
/**
* 这里我不太理解,也许是其他论文里的过程。等看到了再来阐述
*/
if (weightSum > 0) {