AdaBoost算法的主要原理是训练若干个弱分类器,根据训练结果赋予它们不同的权值,最后再将这些弱分类器组合起来,形成一个强分类器。AdaBoost的基本原理在 http://wenku.baidu.com/view/49478920aaea998fcc220e98.html 中已经有很详细的描述。
这里使用上一篇博客中的感知器算法作为弱分类器,代码如下:
首先是adaboost算法的结果类
/**
 * Result holder for an AdaBoost training run: the collection of weak classifiers
 * and the weight assigned to each weak classifier.
 *
 * <p>The lists are stored and returned by reference; this class performs no copying
 * and no validation of their contents.
 *
 * @author zhenhua.chen
 * @date 2013-3-8 15:14:58
 */
public class AdboostResult {
    /** Weak classifiers, each represented as a list of {@code Double} coefficients. */
    private ArrayList<ArrayList<Double>> weakClassifierSet;

    /** Weights of the weak classifiers. */
    private ArrayList<Double> classifierWeightSet;

    /**
     * @return the stored weak classifiers, or {@code null} if none were set
     */
    public ArrayList<ArrayList<Double>> getWeakClassifierSet() {
        return weakClassifierSet;
    }

    /**
     * @param weakClassifierSet the weak classifiers to store (kept by reference)
     */
    public void setWeakClassifierSet(ArrayList<ArrayList<Double>> weakClassifierSet) {
        this.weakClassifierSet = weakClassifierSet;
    }

    /**
     * @return the stored weak-classifier weights, or {@code null} if none were set
     */
    public ArrayList<Double> getClassifierWeightSet() {
        return classifierWeightSet;
    }

    /**
     * @param classifierWeightSet the weak-classifier weights to store (kept by reference)
     */
    public void setClassifierWeightSet(ArrayList<Double> classifierWeightSet) {
        this.classifierWeightSet = classifierWeightSet;
    }
}
adaboost算法:
/**
* http://wenku.baidu.com/view/49478920aaea998fcc220e98.html
* @author zhenhua.chen
* @Description: AdaBoost implementation that uses a perceptron (PerceptronApproach) as the weak classifier
* @date 2013-3-8 下午3:09:36
*
*/
public class AdaboostAlgorithm {
private static final int T = 30; // 迭代次数
PerceptronApproach pa = new PerceptronApproach(); // 弱分类器
/**
*
* @Title: adaboostClassify
* @Description: 通过训练集计算出组合分类器
* @return AdboostResult
* @throws
*/
public AdboostResult adaboostClassify(ArrayList> dataSet) {
AdboostResult res = new AdboostResult();
int dataDimension;
if(null != dataSet && dataSet.size() > 0) {
dataDimension = dataSet.get(0).size();
} else {
return null;
}
// 为每条数据的权重赋初值
ArrayList dataWeightSet = new ArrayList();
for(int i = 0; i < dataSet.size(); i ++) {
dataWeightSet.add(1.0 / (double)dataSet.size());
}
// 存储每个弱分类器的权重
ArrayList classifierWeightSet = new ArrayList();
// 存储每个弱分类器
ArrayList> weakClassifierSet = new ArrayList>();
for(int i = 0; i < T; i++) {
// 计算弱分类器
ArrayList sensorWeightVector = pa.getWeightVector(dataSet, dataWeightSet);
weakClassifierSet.add(sensorWeightVector);
// 计算弱分类器误差
double error = 0; //分类数
int rightClassifyNum = 0;
ArrayList cllassifyResult = new ArrayList();
for(int j = 0; j < dataSet.size(); j++) {
double result = 0;
for(int k = 0; k < dataDimension - 1; k++) {
result += dataSet.get(j).get(k) * sensorWeightVector.get(k);
}
result += sensorWeightVector.get(dataDimension - 1);
if(result < 0) { // 说明预测错误
error += dataWeightSet.get(j);
cllassifyResult.add(-1d);
} else{
cllassifyResult.add(1d);
rightClassifyNum++;
}
}
System.out.println("总数:" + dataSet.size() + "正确预测数" + rightClassifyNum);
if(dataSet.size() == rightClassifyNum) {
classifierWeightSet.clear();
weakClassifierSet.clear();
classifierWeightSet.add(1.0);
weakClassifierSet.add(sensorWeightVector);
break;
}
// 更新数据集中每条数据的权重并归一化
double dataWeightSum = 0;
for(int j = 0; j < dataSet.size(); j++) {
dataWeightSet.set(j, dataWeightSet.get(j) * Math.pow(Math.E, (-1) * 0.5 * Math.log((1 - error) / error) * cllassifyResult.get(j))); // 按照http://wenku.baidu.com/view/49478920aaea998fcc220e98.html,更新的权重少除一个常数
dataWeightSum += dataWeightSet.get(j);
}
for(int j = 0; j < dataSet.size(); j++) {
dataWeightSet.set(j, dataWeightSet.get(j) / dataWeightSum);
}
// 计算次弱分类器的权重
double currentWeight = (0.5 * Math.log((1 - error) / error));
classifierWeightSet.add(currentWeight);
System.out.println("classifier weight: " + currentWeight);
}
res.setClassifierWeightSet(classifierWeightSet);
res.setWeakClassifierSet(weakClassifierSet);
return res;
}
/**
 * Computes the class of one input sample with the boosted classifier.
 *
 * <p>Each weak classifier casts a vote of {@code +1} or {@code -1}: the dot product of
 * its first {@code data.size()} coefficients with {@code data}, plus the coefficient at
 * index {@code data.size()} (the bias), is computed, and a negative value votes
 * {@code -1}, anything else {@code +1}. Votes are summed, each multiplied by the weight
 * of its classifier, so the bias is covered by the classifier weight as well.
 *
 * @param data feature values of the sample, without a trailing label element
 * @param classifier weak classifiers and their weights; each weak classifier must hold
 *        at least {@code data.size() + 1} coefficients
 * @return {@code 1} when the weighted vote sum is positive, {@code -1} otherwise
 *         (including a sum of exactly zero)
 */
public int computeResult(ArrayList<Double> data, AdboostResult classifier) {
    ArrayList<ArrayList<Double>> weakClassifierSet = classifier.getWeakClassifierSet();
    ArrayList<Double> classifierWeightSet = classifier.getClassifierWeightSet();
    int dataSize = data.size();

    double result = 0;
    for (int i = 0; i < weakClassifierSet.size(); i++) {
        ArrayList<Double> weakClassifier = weakClassifierSet.get(i);

        // Linear score of this weak classifier: features first, then the bias.
        double score = 0;
        for (int j = 0; j < dataSize; j++) {
            score += weakClassifier.get(j) * data.get(j);
        }
        score += weakClassifier.get(dataSize);

        // The classifier weights are derived from +1/-1 decisions during training,
        // so the ensemble sums weighted decisions rather than raw scores.
        double vote = score < 0 ? -1d : 1d;
        result += classifierWeightSet.get(i) * vote;
    }

    return result > 0 ? 1 : -1;
}
测试类:
public static void main(String[] args) {
/**
* 测试数据,产生两类随机数据一类位于圆内,另一类位于包含小圆的大圆内,成环状
* 小圆半径为1,大圆半径为2,公共圆心位于(2, 2)内
*/
final int SMALL_CIRCLE_NUM = 24;
final int RING_NUM = 34;
ArrayList> dataSet = new ArrayList>();
// 产生小圆数据
for(int i = 0; i < SMALL_CIRCLE_NUM; i++) {
double x = 1 + Math.random() * 2; // 1到3的随机数
double y = 1 + Math.random() * 2; // 1到3的随机数
if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内
ArrayList smallCircle = new ArrayList();
smallCircle.add(x);
smallCircle.add(y);
smallCircle.add(1d); // 列别1
dataSet.add(smallCircle);
}
}
// 产生外围环形数据
for(int i = 0; i < RING_NUM; i++) {
double x1 = Math.random() * 4;
double y1 = Math.random() * 4;
if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) { //说明位于环形区域内
ArrayList ring = new ArrayList();
ring.add(-x1);
ring.add(-y1);
ring.add(-1d); // 列别2
dataSet.add(ring);
}
}
AdaboostAlgorithm algo = new AdaboostAlgorithm();
AdboostResult result = algo.adaboostClassify(dataSet);
// 产生测试数据
for(int i = 0; i < 10; i++) {
ArrayList testData = new ArrayList();
double x1 = Math.random() * 4;
double y1 = Math.random() * 4;
if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) {
testData.add(x1);
testData.add(y1);
}
//double x = 1 + Math.random() * 2; // 1到3的随机数
//double y = 1 + Math.random() * 2; // 1到3的随机数
//if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内
//testData.add(x);
//testData.add(y);
//}
algo.computeResult(testData, result);
System.out.println(algo.computeResult(testData, result));
}
}