关闭

纯手工实现adaboost(adaboost源码)

标签: adaboost
162人阅读 评论(0) 收藏 举报
分类:

adaboost --ensemble learning的一种 

说明几点 测试集ionosphere.data可以自行下载

检查过几遍 难免疏漏 有问题欢迎提出

感觉大家对adaboost的实现兴趣都不大

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #include "stdio.h"  
  2. #include "assert.h"  
  3. #include "string.h"  
  4. #include "stdlib.h"  
  5. #include "math.h"  
  6. #include "adaboost.h"  
  7.   
  8. #define DATA_NAME "..//dataset//ionosphere.data"  
  9.   
  10. SampleHeader sampleHeader;  
  11. IdxHeader idx;  
  12.   
  13.   
  14. //==================================================================  
  15. //函数名:  sort  
  16. //作者:    qiurenbo  
  17. //日期:    2014-11-25  
  18. //功能:    冒泡排序  
  19. //输入参数:double a[]  
  20. //          n   数组长度  
  21. //返回值:  无  
  22. //修改记录:  
  23. //==================================================================  
  24. void sort(double a[], int n)  
  25. {  
  26.     double tmp;  
  27.     for (int i = 0; i < n - 1; i++)  
  28.     {  
  29.         for (int j = 0; j < n - i - 1; j++)  
  30.         {  
  31.             if (a[j] > a[j+1])  
  32.             {  
  33.                 tmp = a[j];  
  34.                 a[j] = a[j+1];  
  35.                 a[j+1] = tmp;  
  36.             }  
  37.         }  
  38.     }  
  39. }  
  40.   
  41. //==================================================================  
  42. //函数名:  countFeature  
  43. //作者:    qiurenbo  
  44. //日期:    2014-11-25  
  45. //功能:    计算一行上有多少特征  
  46. //输入参数:char* buf 文本的一行  
  47. //返回值:  特征个数  
  48. //修改记录:  
  49. //==================================================================  
  50. int countFeature(const char* buf)  
  51. {  
  52.     const char* p = buf;  
  53.     int   cnt = 0;  
  54.     while(*p != NULL)  
  55.     {  
  56.         if (*p == ',')  
  57.             cnt++;  
  58.         p++;  
  59.     }  
  60.   
  61.     return cnt;  
  62. }  
  63.   
  64. //==================================================================  
  65. //函数名:  setFeature  
  66. //作者:    qiurenbo  
  67. //日期:    2014-11-25  
  68. //功能:    将读取的特征分配到正负样本结构体上  
  69. //输入参数:char* buf 文本的一行  
  70. //返回值:  无  
  71. //修改记录:  
  72. //==================================================================  
  73. void setFeature(char* buf)  
  74. {  
  75.     int i = 0;  
  76.     struct Sample sample;  
  77.   
  78.     char*p = strtok(buf, ",");  
  79.   
  80.     sample.feature[i++] = atof(p);  
  81.       
  82.     while(1)  
  83.     {  
  84.         if (*p != 'g' && *p != 'b')  
  85.             sample.feature[i++] = atof(p);  
  86.           
  87.         else  
  88.             break;  
  89.           
  90.         p = strtok(NULL, ",");  
  91.     }  
  92.   
  93.     if (*p == 'g')  
  94.         sample.indicate = 1;  
  95.     else if (*p == 'b')  
  96.         sample.indicate = -1;  
  97.     else  
  98.         assert(0);  
  99.   
  100.     //idx 每行是所有样本的同一特征  
  101.     for(i = 0; i < sampleHeader.featureNum; i++)  
  102.         idx.feature[i][idx.samplesNum] = sample.feature[i];  
  103.   
  104.     sampleHeader.samples[sampleHeader.samplesNum] = sample;   
  105.     sampleHeader.samplesNum++;  
  106.     idx.samplesNum++;  
  107.       
  108. };  
  109.   
  110.   
  111. //==================================================================  
  112. //函数名:  loadData  
  113. //作者:    qiurenbo  
  114. //日期:    2014-11-25  
  115. //功能:    读取文本数据  
  116. //输入参数:char* buf 文本的一行  
  117. //返回值:  无  
  118. //修改记录:  
  119. //==================================================================  
  120. void loadData()  
  121. {  
  122.     FILE *fp = NULL;   
  123.     char buf[1000];    
  124.     int featureCnt = 0;  
  125.     double* featrue = NULL;  
  126.     double* featruePtr = NULL;  
  127.     int i = 0;  
  128.   
  129.     fp = fopen(DATA_NAME, "r");  
  130.     assert(fp);  
  131.   
  132.   
  133.     fgets(buf, 1000, fp);  
  134.     idx.featureNum = sampleHeader.featureNum = countFeature(buf);  
  135.       
  136.     setFeature(buf);  
  137.     //统计样本数  
  138.       
  139.     while(!feof(fp))  
  140.     {  
  141.           
  142.         fgets(buf, 1000, fp);  
  143.         setFeature(buf);  
  144.           
  145.           
  146.     }  
  147.       
  148.     fclose(fp);  
  149.       
  150.     for (i = 0; i < idx.featureNum; i++)  
  151.         sort(idx.feature[i], idx.samplesNum);  
  152.   
  153.   
  154.       
  155.   
  156.   
  157. }  
  158.   
  159.   
  160. //==================================================================  
  161. //函数名:  CreateStump  
  162. //作者:    qiurenbo  
  163. //日期:    2014-11-26  
  164. //功能:    创建一个stump分类器  
  165. //输入参数:无  
  166. //返回值:  stump  
  167. //修改记录:  
  168. //==================================================================  
  169. Stump CreateStump()  
  170. {  
  171.     int i,j,k;  
  172.     Stump stump;  
  173.     double min = 0xffffffff;  
  174.     double err = 0;  
  175.     double flipErr = 0;  
  176.       
  177.     double feature;  
  178.     int indicate;  
  179.     double weight;  
  180.     double pre;  
  181.     for( i = 0; i < idx.featureNum; i++)  
  182.     {  
  183.         pre = 0xffffffff;  
  184.         for(j = 0; j < idx.samplesNum; j++)  
  185.         {  
  186.           
  187.             err = 0;  
  188.             double rootFeature = idx.feature[i][j];  
  189.   
  190.             //跳过相同的值  
  191.             if (pre == rootFeature)  
  192.                 continue;  
  193.   
  194.   
  195.             for (k = 0; k < sampleHeader.samplesNum; k++)  
  196.             {  
  197.                 feature = sampleHeader.samples[k].feature[i];  
  198.                 indicate = sampleHeader.samples[k].indicate;  
  199.                 weight = sampleHeader.samples[k].weight;  
  200.                 if ((feature <  rootFeature  && indicate != 1) ||\  
  201.                     (feature >= rootFeature && indicate != -1)   
  202.                     )  
  203.                     err += weight;    
  204.   
  205.             }  
  206.   
  207.             //左边是1,还是右边是1,选取error最小的组合  
  208.             flipErr = 1 - err;  
  209.             err = err < flipErr ? err:flipErr;  
  210.               
  211.             //选取具有最小err的特征rootFeature  
  212.             if (err < min)  
  213.             {  
  214.                 min = err;  
  215.                 stump.fIdx = i;  
  216.                 stump.ft = rootFeature;  
  217.                 if (err < flipErr)  
  218.                 {  
  219.                     stump.left = 1;  
  220.                     stump.right = -1;  
  221.                 }             
  222.                 else  
  223.                 {  
  224.                       
  225.                     stump.left = -1;  
  226.                     stump.right = 1;  
  227.                 }  
  228.             }  
  229.       
  230.             pre = rootFeature;  
  231.         }  
  232.     }  
  233.       
  234.       
  235.     stump.alpha = 0.5*log(1.0/min - 1);  
  236.     return stump;  
  237. }  
  238.   
  239.   
  240. //==================================================================  
  241. //函数名:  reSetWeight  
  242. //作者:    qiurenbo  
  243. //日期:    2014-11-26  
  244. //功能:    每次迭代重新调整权重  
  245. //输入参数:stump  
  246. //返回值:  无  
  247. //修改记录:  
  248. //==================================================================  
  249. void reSetWeight(struct Stump stump)  
  250. {  
  251.     int i;  
  252.     double z = 0;  
  253.   
  254.     //计算规范化因子z  
  255.     for(i = 0; i < sampleHeader.samplesNum; i++)  
  256.     {  
  257.         double feature = (sampleHeader.samples[i]).feature[stump.fIdx];  
  258.         double rs = feature < stump.ft ? stump.left:stump.right;  
  259.         rs = stump.alpha * rs * sampleHeader.samples[i].indicate;  
  260.   
  261.         z += sampleHeader.samples[i].weight * exp(-1.0 * rs);  
  262.     }  
  263.       
  264.   
  265.       
  266.     //调整各个样本的权值  
  267.     for(i = 0; i < sampleHeader.samplesNum; i++)  
  268.     {  
  269.         double feature = (sampleHeader.samples[i]).feature[stump.fIdx];  
  270.         double rs = feature < stump.ft ? stump.left:stump.right;  
  271.         rs = stump.alpha * rs * sampleHeader.samples[i].indicate;  
  272.           
  273.         sampleHeader.samples[i].weight= sampleHeader.samples[i].weight * exp(-1.0 * rs) / z;  
  274.           
  275.           
  276.     }  
  277.   
  278. #ifdef DEBUG  
  279.   
  280.     //debug  
  281.     for(i = 0; i < 10; i++)  
  282.     {  
  283.         double feature = (sampleHeader.samples[i]).feature[stump.fIdx];  
  284.         double rs = feature < stump.ft ? stump.left:stump.right;  
  285.         rs = stump.alpha * rs * sampleHeader.samples[i].indicate;  
  286.         printf("weight:%lf, rs:%lf\n",sampleHeader.samples[i].weight , rs);  
  287.     }  
  288.       
  289.   
  290.     //getchar();  
  291.       
  292. #endif  
  293. }  
  294.   
  295. //==================================================================  
  296. //函数名:  AdaBoost  
  297. //作者:    qiurenbo  
  298. //日期:    2014-11-26  
  299. //功能:    adaboost训练弱分类器  
  300. //输入参数:interation  迭代次数  
  301. //返回值:  无  
  302. //修改记录:  
  303. //==================================================================  
  304. void AdaBoost(int interation)  
  305. {  
  306.     int i;  
  307.     struct ClassifierHeader head;  
  308.     struct Classifier* pCls = NULL;  
  309.     struct Classifier* tmp = NULL;  
  310.     head.classifierNum = interation;  
  311.   
  312.     loadData();  
  313.       
  314.     //设置初始样本权重  
  315.     for(i = 0; i < sampleHeader.samplesNum; i++)  
  316.         sampleHeader.samples[i].weight = 1.0 / sampleHeader.samplesNum;  
  317.   
  318.       
  319.     head.classifier = (struct Classifier*)malloc(sizeof(struct Classifier));  
  320.     pCls = head.classifier;  
  321.     pCls->stump = CreateStump();  
  322.     reSetWeight(pCls->stump);  
  323.     //printf("completed:%lf%%\r", 1.0/head.classifierNum*100);  
  324.     printf("+-----------+--+-------+\n" );  
  325.     printf("|   alpha   |id|  ft   |\n");  
  326.     printf("+-----------+--+-------+\n" );  
  327.     printf("|%.9lf|%2d|%+.4lf|\n", pCls->stump.alpha, pCls->stump.fIdx, pCls->stump.ft);  
  328.     printf( "+-----------+--+-------+\n" );  
  329.     for (i = 1; i < head.classifierNum; i++)  
  330.     {     
  331.       
  332.         pCls = pCls->next = (struct Classifier*)malloc(sizeof(struct Classifier));     
  333.       
  334.         pCls->stump = CreateStump();  
  335.         reSetWeight(pCls->stump);  
  336.         printf("|%.9lf|%2d|%+.4lf|\n", pCls->stump.alpha, pCls->stump.fIdx, pCls->stump.ft);  
  337.         printf( "+-----------+--+-------+\n" );  
  338.         //printf("completed:%lf%%\r", 1.0*(i+1)/head.classifierNum*100);  
  339.           
  340.     }  
  341.   
  342.       
  343.     printf("\n");  
  344.   
  345.     for(i = 0, pCls = head.classifier; i < head.classifierNum; i++)  
  346.     {  
  347.         tmp = pCls;  
  348.         pCls = tmp->next;  
  349.         free(tmp);  
  350.     }  
  351.   
  352. }  
  353.   
  354.   
  355. void main()  
  356. {  
  357.   
  358.     AdaBoost(100);  
  359. }  

adaboost.h

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #ifndef _ADABOOST_H_  
  2. #define _ADABOOST_H_   
  3. #define MAX_FEATURE 100  
  4. #define MAX_SAMPLES 500  
  5. //#define DEBUG    
  6. struct Sample  
  7. {  
  8.     double weight;  
  9.     double feature[MAX_FEATURE];  
  10.     int    indicate;  
  11. };    
  12. struct SampleHeader  
  13. {  
  14.     int samplesNum;  
  15.     int featureNum;  
  16.       
  17.     //double feature[MAX_SAMPLES][MAX_FEATURE];  
  18.     struct Sample samples[MAX_SAMPLES];  
  19.       
  20. };  
  21.   
  22. struct Stump  
  23. {  
  24.     int left;  
  25.     int right;  
  26.     double alpha;  
  27.     int fIdx;  
  28.     double ft;  
  29. };  
  30.   
  31. struct Classifier  
  32. {  
  33.     struct Stump stump;  
  34.     struct Classifier* next;  
  35. };  
  36.   
  37.   
  38. struct ClassifierHeader  
  39. {  
  40.     int classifierNum;  
  41.     struct Classifier* classifier;  
  42. };  
  43. struct IdxHeader  
  44. {  
  45.     int samplesNum;  
  46.     int featureNum;  
  47.       
  48.     double feature[MAX_FEATURE][MAX_SAMPLES];  
  49.       
  50. };  
  51.   
  52. #endif 
0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:80989次
    • 积分:1119
    • 等级:
    • 排名:千里之外
    • 原创:21篇
    • 转载:58篇
    • 译文:0篇
    • 评论:10条
    最新评论