adaboost算法代码 java_java实现adaboost算法

这篇博客介绍了如何用Java实现AdaBoost算法,该算法通过训练多个弱分类器并赋予不同权重,组合成一个强分类器。文章详细讲解了算法的基本原理,并提供了感知器作为弱分类器的代码实现,包括AdaboostResult类和AdaboostAlgorithm类的详细代码。此外,还给出了测试数据的生成和分类结果的计算方法。
摘要由CSDN通过智能技术生成

adaboost算法的主要原理是训练若干个弱分类器,根据训练结果赋予它们不同的权值,最后再将这些弱分类器组合起来,形成一个强分类器,adaboost的基本原理在http://wenku.baidu.com/view/49478920aaea998fcc220e98.html###中已经有很详细的描述

这里使用上一篇博客中的感知器算法作为弱分类器,代码如下:

首先是adaboost算法的结果类

/**

*

* @author zhenhua.chen

* @Description: adboost算法的结果类,包括弱分类器的集合和每个弱分类器的权重

* @date 2013-3-8 下午3:14:58

*

*/

public class AdboostResult {

private ArrayList> weakClassifierSet;

private ArrayList classifierWeightSet;

public ArrayList> getWeakClassifierSet() {

return weakClassifierSet;

}

public void setWeakClassifierSet(ArrayList> weakClassifierSet) {

this.weakClassifierSet = weakClassifierSet;

}

public ArrayList getClassifierWeightSet() {

return classifierWeightSet;

}

public void setClassifierWeightSet(ArrayList classifierWeightSet) {

this.classifierWeightSet = classifierWeightSet;

}

}

adaboost算法:

/**

* http://wenku.baidu.com/view/49478920aaea998fcc220e98.html

* @author zhenhua.chen

* @Description: TODO

* @date 2013-3-8 下午3:09:36

*

*/

public class AdaboostAlgorithm {

private static final int T = 30; // 迭代次数

PerceptronApproach pa = new PerceptronApproach(); // 弱分类器

/**

*

* @Title: adaboostClassify

* @Description: 通过训练集计算出组合分类器

* @return AdboostResult

* @throws

*/

public AdboostResult adaboostClassify(ArrayList> dataSet) {

AdboostResult res = new AdboostResult();

int dataDimension;

if(null != dataSet && dataSet.size() > 0) {

dataDimension = dataSet.get(0).size();

} else {

return null;

}

// 为每条数据的权重赋初值

ArrayList dataWeightSet = new ArrayList();

for(int i = 0; i < dataSet.size(); i ++) {

dataWeightSet.add(1.0 / (double)dataSet.size());

}

// 存储每个弱分类器的权重

ArrayList classifierWeightSet = new ArrayList();

// 存储每个弱分类器

ArrayList> weakClassifierSet = new ArrayList>();

for(int i = 0; i < T; i++) {

// 计算弱分类器

ArrayList sensorWeightVector = pa.getWeightVector(dataSet, dataWeightSet);

weakClassifierSet.add(sensorWeightVector);

// 计算弱分类器误差

double error = 0; //分类数

int rightClassifyNum = 0;

ArrayList cllassifyResult = new ArrayList();

for(int j = 0; j < dataSet.size(); j++) {

double result = 0;

for(int k = 0; k < dataDimension - 1; k++) {

result += dataSet.get(j).get(k) * sensorWeightVector.get(k);

}

result += sensorWeightVector.get(dataDimension - 1);

if(result < 0) { // 说明预测错误

error += dataWeightSet.get(j);

cllassifyResult.add(-1d);

} else{

cllassifyResult.add(1d);

rightClassifyNum++;

}

}

System.out.println("总数:" + dataSet.size() + "正确预测数" + rightClassifyNum);

if(dataSet.size() == rightClassifyNum) {

classifierWeightSet.clear();

weakClassifierSet.clear();

classifierWeightSet.add(1.0);

weakClassifierSet.add(sensorWeightVector);

break;

}

// 更新数据集中每条数据的权重并归一化

double dataWeightSum = 0;

for(int j = 0; j < dataSet.size(); j++) {

dataWeightSet.set(j, dataWeightSet.get(j) * Math.pow(Math.E, (-1) * 0.5 * Math.log((1 - error) / error) * cllassifyResult.get(j))); // 按照http://wenku.baidu.com/view/49478920aaea998fcc220e98.html,更新的权重少除一个常数

dataWeightSum += dataWeightSet.get(j);

}

for(int j = 0; j < dataSet.size(); j++) {

dataWeightSet.set(j, dataWeightSet.get(j) / dataWeightSum);

}

// 计算次弱分类器的权重

double currentWeight = (0.5 * Math.log((1 - error) / error));

classifierWeightSet.add(currentWeight);

System.out.println("classifier weight: " + currentWeight);

}

res.setClassifierWeightSet(classifierWeightSet);

res.setWeakClassifierSet(weakClassifierSet);

return res;

}

/**

*

* @Title: computeResult

* @Description: 计算输入数据的类别

* @return double

* @throws

*/

public int computeResult(ArrayList data, AdboostResult classifier) {

double result = 0;

int dataSize = data.size();

ArrayList> weakClassifierSet = classifier.getWeakClassifierSet();

ArrayList classifierWeightSet = classifier.getClassifierWeightSet();

for(int i = 0; i < weakClassifierSet.size(); i++) {

for(int j = 0; j < dataSize; j++) {

result += weakClassifierSet.get(i).get(j) * data.get(j) * classifierWeightSet.get(i);

}

result += weakClassifierSet.get(i).get(dataSize);

}

if(result > 0) {

return 1;

} else {

return -1;

}

}

测试类:

public static void main(String[] args) {

/**

* 测试数据,产生两类随机数据一类位于圆内,另一类位于包含小圆的大圆内,成环状

* 小圆半径为1,大圆半径为2,公共圆心位于(2, 2)内

*/

final int SMALL_CIRCLE_NUM = 24;

final int RING_NUM = 34;

ArrayList> dataSet = new ArrayList>();

// 产生小圆数据

for(int i = 0; i < SMALL_CIRCLE_NUM; i++) {

double x = 1 + Math.random() * 2; // 1到3的随机数

double y = 1 + Math.random() * 2; // 1到3的随机数

if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内

ArrayList smallCircle = new ArrayList();

smallCircle.add(x);

smallCircle.add(y);

smallCircle.add(1d); // 列别1

dataSet.add(smallCircle);

}

}

// 产生外围环形数据

for(int i = 0; i < RING_NUM; i++) {

double x1 = Math.random() * 4;

double y1 = Math.random() * 4;

if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) { //说明位于环形区域内

ArrayList ring = new ArrayList();

ring.add(-x1);

ring.add(-y1);

ring.add(-1d); // 列别2

dataSet.add(ring);

}

}

AdaboostAlgorithm algo = new AdaboostAlgorithm();

AdboostResult result = algo.adaboostClassify(dataSet);

// 产生测试数据

for(int i = 0; i < 10; i++) {

ArrayList testData = new ArrayList();

double x1 = Math.random() * 4;

double y1 = Math.random() * 4;

if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) {

testData.add(x1);

testData.add(y1);

}

//double x = 1 + Math.random() * 2; // 1到3的随机数

//double y = 1 + Math.random() * 2; // 1到3的随机数

//if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内

//testData.add(x);

//testData.add(y);

//}

algo.computeResult(testData, result);

System.out.println(algo.computeResult(testData, result));

}

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值