CBA算法---基于关联规则进行分类的算法 CBA算法---基于关联规则进行分类的算法CBA算法---基于关联规则进行分类的算法...



关闭

CBA算法---基于关联规则进行分类的算法

标签:机器学习数据挖掘算法数据
823人阅读 评论(0) 收藏 举报
分类:
算法(44) 数据挖掘(32) 机器学习(30)

目录(?)[+]

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。

算法原理

CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:

1、输入数据记录,就是一条条的属性值。

2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。

3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。

4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。

这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 113HighNoFairCLassNo
  3. 211HighNoExcellentCLassNo
  4. 325HighNoFairCLassYes
  5. 445MediumNoFairCLassYes
  6. 550LowYesFairCLassYes
  7. 651LowYesExcellentCLassNo
  8. 730LowYesExcellentCLassYes
  9. 813MediumNoFairCLassNo
  10. 99LowYesFairCLassYes
  11. 1055MediumYesFairCLassYes
  12. 1114MediumYesExcellentCLassYes
  13. 1233MediumNoExcellentCLassYes
  14. 1333HighYesFairCLassYes
  15. 1441MediumNoExcellentCLassNo
属性值对应的数字替换图:

  1. Medium=5,CLassYes=12,Excellent=10,Low=6,Fair=9,CLassNo=11,Young=1,Middle_aged=2,Yes=8,No=7,High=4,Senior=3
体会之后的数据变为了下面的事务项:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 1147911
  3. 21471011
  4. 3247912
  5. 4357912
  6. 5368912
  7. 63681011
  8. 72681012
  9. 8157911
  10. 9168912
  11. 10358912
  12. 111581012
  13. 122571012
  14. 13248912
  15. 143571011
把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以 点击这里进行了解。

算法的代码实现

测试数据就是上面的内容。

CBATool.java:

  1. packageDataMining_CBA;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.HashMap;
  8. importjava.util.regex.Matcher;
  9. importjava.util.regex.Pattern;
  10. importDataMining_CBA.AprioriTool.AprioriTool;
  11. importDataMining_CBA.AprioriTool.FrequentItem;
  12. /**
  13. *CBA算法(关联规则分类)工具类
  14. *
  15. *@authorlyq
  16. *
  17. */
  18. publicclassCBATool{
  19. //年龄的类别划分
  20. publicfinalStringAGE="Age";
  21. publicfinalStringAGE_YOUNG="Young";
  22. publicfinalStringAGE_MIDDLE_AGED="Middle_aged";
  23. publicfinalStringAGE_Senior="Senior";
  24. //测试数据地址
  25. privateStringfilePath;
  26. //最小支持度阈值率
  27. privatedoubleminSupportRate;
  28. //最小置信度阈值,用来判断是否能够成为关联规则
  29. privatedoubleminConf;
  30. //最小支持度
  31. privateintminSupportCount;
  32. //属性列名称
  33. privateString[]attrNames;
  34. //类别属性所代表的数字集合
  35. privateArrayList<Integer>classTypes;
  36. //用二维数组保存测试数据
  37. privateArrayList<String[]>totalDatas;
  38. //Apriori算法工具类
  39. privateAprioriToolaprioriTool;
  40. //属性到数字的映射图
  41. privateHashMap<String,Integer>attr2Num;
  42. privateHashMap<Integer,String>num2Attr;
  43. publicCBATool(StringfilePath,doubleminSupportRate,doubleminConf){
  44. this.filePath=filePath;
  45. this.minConf=minConf;
  46. this.minSupportRate=minSupportRate;
  47. readDataFile();
  48. }
  49. /**
  50. *从文件中读取数据
  51. */
  52. privatevoidreadDataFile(){
  53. Filefile=newFile(filePath);
  54. ArrayList<String[]>dataArray=newArrayList<String[]>();
  55. try{
  56. BufferedReaderin=newBufferedReader(newFileReader(file));
  57. Stringstr;
  58. String[]tempArray;
  59. while((str=in.readLine())!=null){
  60. tempArray=str.split("");
  61. dataArray.add(tempArray);
  62. }
  63. in.close();
  64. }catch(IOExceptione){
  65. e.getStackTrace();
  66. }
  67. totalDatas=newArrayList<>();
  68. for(String[]array:dataArray){
  69. totalDatas.add(array);
  70. }
  71. attrNames=totalDatas.get(0);
  72. minSupportCount=(int)(minSupportRate*totalDatas.size());
  73. attributeReplace();
  74. }
  75. /**
  76. *属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
  77. */
  78. privatevoidattributeReplace(){
  79. intcurrentValue=1;
  80. intnum=0;
  81. Strings;
  82. //属性名到数字的映射图
  83. attr2Num=newHashMap<>();
  84. num2Attr=newHashMap<>();
  85. classTypes=newArrayList<>();
  86. //按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
  87. for(intj=1;j<attrNames.length;j++){
  88. for(inti=1;i<totalDatas.size();i++){
  89. s=totalDatas.get(i)[j];
  90. //如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
  91. if(attrNames[j].equals(AGE)){
  92. num=Integer.parseInt(s);
  93. if(num<=20&&num>0){
  94. totalDatas.get(i)[j]=AGE_YOUNG;
  95. }elseif(num>20&&num<=40){
  96. totalDatas.get(i)[j]=AGE_MIDDLE_AGED;
  97. }elseif(num>40){
  98. totalDatas.get(i)[j]=AGE_Senior;
  99. }
  100. }
  101. if(!attr2Num.containsKey(totalDatas.get(i)[j])){
  102. attr2Num.put(totalDatas.get(i)[j],currentValue);
  103. num2Attr.put(currentValue,totalDatas.get(i)[j]);
  104. if(j==attrNames.length-1){
  105. //如果是组后一列,说明是分类类别列,记录下来
  106. classTypes.add(currentValue);
  107. }
  108. currentValue++;
  109. }
  110. }
  111. }
  112. //对原始的数据作属性替换,每条记录变为类似于事务数据的形式
  113. for(inti=1;i<totalDatas.size();i++){
  114. for(intj=1;j<attrNames.length;j++){
  115. s=totalDatas.get(i)[j];
  116. if(attr2Num.containsKey(s)){
  117. totalDatas.get(i)[j]=attr2Num.get(s)+"";
  118. }
  119. }
  120. }
  121. }
  122. /**
  123. *Apriori计算全部频繁项集
  124. *@return
  125. */
  126. privateArrayList<FrequentItem>aprioriCalculate(){
  127. String[]tempArray;
  128. ArrayList<FrequentItem>totalFrequentItems;
  129. ArrayList<String[]>copyData=(ArrayList<String[]>)totalDatas.clone();
  130. //去除属性名称行
  131. copyData.remove(0);
  132. //去除首列ID
  133. for(inti=0;i<copyData.size();i++){
  134. String[]array=copyData.get(i);
  135. tempArray=newString[array.length-1];
  136. System.arraycopy(array,1,tempArray,0,tempArray.length);
  137. copyData.set(i,tempArray);
  138. }
  139. aprioriTool=newAprioriTool(copyData,minSupportCount);
  140. aprioriTool.computeLink();
  141. totalFrequentItems=aprioriTool.getTotalFrequentItems();
  142. returntotalFrequentItems;
  143. }
  144. /**
  145. *基于关联规则的分类
  146. *
  147. *@paramattrValues
  148. *预先知道的一些属性
  149. *@return
  150. */
  151. publicStringCBAJudge(StringattrValues){
  152. intvalue=0;
  153. //最终分类类别
  154. StringclassType=null;
  155. String[]tempArray;
  156. //已知的属性值
  157. ArrayList<String>attrValueList=newArrayList<>();
  158. ArrayList<FrequentItem>totalFrequentItems;
  159. totalFrequentItems=aprioriCalculate();
  160. //将查询条件进行逐一属性的分割
  161. String[]array=attrValues.split(",");
  162. for(Stringrecord:array){
  163. tempArray=record.split("=");
  164. value=attr2Num.get(tempArray[1]);
  165. attrValueList.add(value+"");
  166. }
  167. //在频繁项集中寻找符合条件的项
  168. for(FrequentItemitem:totalFrequentItems){
  169. //过滤掉不满足个数频繁项
  170. if(item.getIdArray().length<(attrValueList.size()+1)){
  171. continue;
  172. }
  173. //要保证查询的属性都包含在频繁项集中
  174. if(itemIsSatisfied(item,attrValueList)){
  175. tempArray=item.getIdArray();
  176. classType=classificationBaseRules(tempArray);
  177. if(classType!=null){
  178. //作属性替换
  179. classType=num2Attr.get(Integer.parseInt(classType));
  180. break;
  181. }
  182. }
  183. }
  184. returnclassType;
  185. }
  186. /**
  187. *基于关联规则进行分类
  188. *
  189. *@paramitems
  190. *频繁项
  191. *@return
  192. */
  193. privateStringclassificationBaseRules(String[]items){
  194. StringclassType=null;
  195. String[]arrayTemp;
  196. intcount1=0;
  197. intcount2=0;
  198. //置信度
  199. doubleconfidenceRate;
  200. String[]noClassTypeItems=newString[items.length-1];
  201. for(inti=0,k=0;i<items.length;i++){
  202. if(!classTypes.contains(Integer.parseInt(items[i]))){
  203. noClassTypeItems[k]=items[i];
  204. k++;
  205. }else{
  206. classType=items[i];
  207. }
  208. }
  209. for(String[]array:totalDatas){
  210. //去除ID数字号
  211. arrayTemp=newString[array.length-1];
  212. System.arraycopy(array,1,arrayTemp,0,array.length-1);
  213. if(isStrArrayContain(arrayTemp,noClassTypeItems)){
  214. count1++;
  215. if(isStrArrayContain(arrayTemp,items)){
  216. count2++;
  217. }
  218. }
  219. }
  220. //做置信度的计算
  221. confidenceRate=count1*1.0/count2;
  222. if(confidenceRate>=minConf){
  223. returnclassType;
  224. }else{
  225. //如果不满足最小置信度要求,则此关联规则无效
  226. returnnull;
  227. }
  228. }
  229. /**
  230. *判断单个字符是否包含在字符数组中
  231. *
  232. *@paramarray
  233. *字符数组
  234. *@params
  235. *判断的单字符
  236. *@return
  237. */
  238. privatebooleanstrIsContained(String[]array,Strings){
  239. booleanisContained=false;
  240. for(Stringstr:array){
  241. if(str.equals(s)){
  242. isContained=true;
  243. break;
  244. }
  245. }
  246. returnisContained;
  247. }
  248. /**
  249. *数组array2是否包含于array1中,不需要完全一样
  250. *
  251. *@paramarray1
  252. *@paramarray2
  253. *@return
  254. */
  255. privatebooleanisStrArrayContain(String[]array1,String[]array2){
  256. booleanisContain=true;
  257. for(Strings2:array2){
  258. isContain=false;
  259. for(Strings1:array1){
  260. //只要s2字符存在于array1中,这个字符就算包含在array1中
  261. if(s2.equals(s1)){
  262. isContain=true;
  263. break;
  264. }
  265. }
  266. //一旦发现不包含的字符,则array2数组不包含于array1中
  267. if(!isContain){
  268. break;
  269. }
  270. }
  271. returnisContain;
  272. }
  273. /**
  274. *判断频繁项集是否满足查询
  275. *
  276. *@paramitem
  277. *待判断的频繁项集
  278. *@paramattrValues
  279. *查询的属性值列表
  280. *@return
  281. */
  282. privatebooleanitemIsSatisfied(FrequentItemitem,
  283. ArrayList<String>attrValues){
  284. booleanisContained=false;
  285. String[]array=item.getIdArray();
  286. for(Strings:attrValues){
  287. isContained=true;
  288. if(!strIsContained(array,s)){
  289. isContained=false;
  290. break;
  291. }
  292. if(!isContained){
  293. break;
  294. }
  295. }
  296. if(isContained){
  297. isContained=false;
  298. //还要验证是否频繁项集中是否包含分类属性
  299. for(Integertype:classTypes){
  300. if(strIsContained(array,type+"")){
  301. isContained=true;
  302. break;
  303. }
  304. }
  305. }
  306. returnisContained;
  307. }
  308. }
调用类Client.java:

  1. packageDataMining_CBA;
  2. importjava.text.MessageFormat;
  3. /**
  4. *CBA算法--基于关联规则的分类算法
  5. *@authorlyq
  6. *
  7. */
  8. publicclassClient{
  9. publicstaticvoidmain(String[]args){
  10. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  11. StringattrDesc="Age=Senior,CreditRating=Fair";
  12. Stringclassification=null;
  13. //最小支持度阈值率
  14. doubleminSupportRate=0.2;
  15. //最小置信度阈值
  16. doubleminConf=0.7;
  17. CBATooltool=newCBATool(filePath,minSupportRate,minConf);
  18. classification=tool.CBAJudge(attrDesc);
  19. System.out.println(MessageFormat.format("{0}的关联分类结果为{1}",attrDesc,classification));
  20. }
  21. }
代码的结果为:

  1. 频繁1项集:
  2. {1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
  3. 频繁2项集:
  4. {1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
  5. 频繁3项集:
  6. {1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
  7. 频繁4项集:
  8. 频繁5项集:
  9. 频繁6项集:
  10. 频繁7项集:
  11. 频繁8项集:
  12. 频繁9项集:
  13. 频繁10项集:
  14. 频繁11项集:
  15. Age=Senior,CreditRating=Fair的关联分类结果为CLassYes
上面的有些项集为空说明没有此项集。Apriori算法类可以在 这里进行查阅,这里只展示了CBA算法的部分。

算法的分析

我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。

程序实现时遇到的问题

在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。

我对CBA算法的理解

CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。

关闭

CBA算法---基于关联规则进行分类的算法

标签:机器学习数据挖掘算法数据
823人阅读 评论(0) 收藏 举报
分类:
算法(44) 数据挖掘(32) 机器学习(30)

目录(?)[+]

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。

算法原理

CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:

1、输入数据记录,就是一条条的属性值。

2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。

3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。

4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。

这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 113HighNoFairCLassNo
  3. 211HighNoExcellentCLassNo
  4. 325HighNoFairCLassYes
  5. 445MediumNoFairCLassYes
  6. 550LowYesFairCLassYes
  7. 651LowYesExcellentCLassNo
  8. 730LowYesExcellentCLassYes
  9. 813MediumNoFairCLassNo
  10. 99LowYesFairCLassYes
  11. 1055MediumYesFairCLassYes
  12. 1114MediumYesExcellentCLassYes
  13. 1233MediumNoExcellentCLassYes
  14. 1333HighYesFairCLassYes
  15. 1441MediumNoExcellentCLassNo
属性值对应的数字替换图:

  1. Medium=5,CLassYes=12,Excellent=10,Low=6,Fair=9,CLassNo=11,Young=1,Middle_aged=2,Yes=8,No=7,High=4,Senior=3
体会之后的数据变为了下面的事务项:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 1147911
  3. 21471011
  4. 3247912
  5. 4357912
  6. 5368912
  7. 63681011
  8. 72681012
  9. 8157911
  10. 9168912
  11. 10358912
  12. 111581012
  13. 122571012
  14. 13248912
  15. 143571011
把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以 点击这里进行了解。

算法的代码实现

测试数据就是上面的内容。

CBATool.java:

  1. packageDataMining_CBA;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.HashMap;
  8. importjava.util.regex.Matcher;
  9. importjava.util.regex.Pattern;
  10. importDataMining_CBA.AprioriTool.AprioriTool;
  11. importDataMining_CBA.AprioriTool.FrequentItem;
  12. /**
  13. *CBA算法(关联规则分类)工具类
  14. *
  15. *@authorlyq
  16. *
  17. */
  18. publicclassCBATool{
  19. //年龄的类别划分
  20. publicfinalStringAGE="Age";
  21. publicfinalStringAGE_YOUNG="Young";
  22. publicfinalStringAGE_MIDDLE_AGED="Middle_aged";
  23. publicfinalStringAGE_Senior="Senior";
  24. //测试数据地址
  25. privateStringfilePath;
  26. //最小支持度阈值率
  27. privatedoubleminSupportRate;
  28. //最小置信度阈值,用来判断是否能够成为关联规则
  29. privatedoubleminConf;
  30. //最小支持度
  31. privateintminSupportCount;
  32. //属性列名称
  33. privateString[]attrNames;
  34. //类别属性所代表的数字集合
  35. privateArrayList<Integer>classTypes;
  36. //用二维数组保存测试数据
  37. privateArrayList<String[]>totalDatas;
  38. //Apriori算法工具类
  39. privateAprioriToolaprioriTool;
  40. //属性到数字的映射图
  41. privateHashMap<String,Integer>attr2Num;
  42. privateHashMap<Integer,String>num2Attr;
  43. publicCBATool(StringfilePath,doubleminSupportRate,doubleminConf){
  44. this.filePath=filePath;
  45. this.minConf=minConf;
  46. this.minSupportRate=minSupportRate;
  47. readDataFile();
  48. }
  49. /**
  50. *从文件中读取数据
  51. */
  52. privatevoidreadDataFile(){
  53. Filefile=newFile(filePath);
  54. ArrayList<String[]>dataArray=newArrayList<String[]>();
  55. try{
  56. BufferedReaderin=newBufferedReader(newFileReader(file));
  57. Stringstr;
  58. String[]tempArray;
  59. while((str=in.readLine())!=null){
  60. tempArray=str.split("");
  61. dataArray.add(tempArray);
  62. }
  63. in.close();
  64. }catch(IOExceptione){
  65. e.getStackTrace();
  66. }
  67. totalDatas=newArrayList<>();
  68. for(String[]array:dataArray){
  69. totalDatas.add(array);
  70. }
  71. attrNames=totalDatas.get(0);
  72. minSupportCount=(int)(minSupportRate*totalDatas.size());
  73. attributeReplace();
  74. }
  75. /**
  76. *属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
  77. */
  78. privatevoidattributeReplace(){
  79. intcurrentValue=1;
  80. intnum=0;
  81. Strings;
  82. //属性名到数字的映射图
  83. attr2Num=newHashMap<>();
  84. num2Attr=newHashMap<>();
  85. classTypes=newArrayList<>();
  86. //按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
  87. for(intj=1;j<attrNames.length;j++){
  88. for(inti=1;i<totalDatas.size();i++){
  89. s=totalDatas.get(i)[j];
  90. //如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
  91. if(attrNames[j].equals(AGE)){
  92. num=Integer.parseInt(s);
  93. if(num<=20&&num>0){
  94. totalDatas.get(i)[j]=AGE_YOUNG;
  95. }elseif(num>20&&num<=40){
  96. totalDatas.get(i)[j]=AGE_MIDDLE_AGED;
  97. }elseif(num>40){
  98. totalDatas.get(i)[j]=AGE_Senior;
  99. }
  100. }
  101. if(!attr2Num.containsKey(totalDatas.get(i)[j])){
  102. attr2Num.put(totalDatas.get(i)[j],currentValue);
  103. num2Attr.put(currentValue,totalDatas.get(i)[j]);
  104. if(j==attrNames.length-1){
  105. //如果是组后一列,说明是分类类别列,记录下来
  106. classTypes.add(currentValue);
  107. }
  108. currentValue++;
  109. }
  110. }
  111. }
  112. //对原始的数据作属性替换,每条记录变为类似于事务数据的形式
  113. for(inti=1;i<totalDatas.size();i++){
  114. for(intj=1;j<attrNames.length;j++){
  115. s=totalDatas.get(i)[j];
  116. if(attr2Num.containsKey(s)){
  117. totalDatas.get(i)[j]=attr2Num.get(s)+"";
  118. }
  119. }
  120. }
  121. }
  122. /**
  123. *Apriori计算全部频繁项集
  124. *@return
  125. */
  126. privateArrayList<FrequentItem>aprioriCalculate(){
  127. String[]tempArray;
  128. ArrayList<FrequentItem>totalFrequentItems;
  129. ArrayList<String[]>copyData=(ArrayList<String[]>)totalDatas.clone();
  130. //去除属性名称行
  131. copyData.remove(0);
  132. //去除首列ID
  133. for(inti=0;i<copyData.size();i++){
  134. String[]array=copyData.get(i);
  135. tempArray=newString[array.length-1];
  136. System.arraycopy(array,1,tempArray,0,tempArray.length);
  137. copyData.set(i,tempArray);
  138. }
  139. aprioriTool=newAprioriTool(copyData,minSupportCount);
  140. aprioriTool.computeLink();
  141. totalFrequentItems=aprioriTool.getTotalFrequentItems();
  142. returntotalFrequentItems;
  143. }
  144. /**
  145. *基于关联规则的分类
  146. *
  147. *@paramattrValues
  148. *预先知道的一些属性
  149. *@return
  150. */
  151. publicStringCBAJudge(StringattrValues){
  152. intvalue=0;
  153. //最终分类类别
  154. StringclassType=null;
  155. String[]tempArray;
  156. //已知的属性值
  157. ArrayList<String>attrValueList=newArrayList<>();
  158. ArrayList<FrequentItem>totalFrequentItems;
  159. totalFrequentItems=aprioriCalculate();
  160. //将查询条件进行逐一属性的分割
  161. String[]array=attrValues.split(",");
  162. for(Stringrecord:array){
  163. tempArray=record.split("=");
  164. value=attr2Num.get(tempArray[1]);
  165. attrValueList.add(value+"");
  166. }
  167. //在频繁项集中寻找符合条件的项
  168. for(FrequentItemitem:totalFrequentItems){
  169. //过滤掉不满足个数频繁项
  170. if(item.getIdArray().length<(attrValueList.size()+1)){
  171. continue;
  172. }
  173. //要保证查询的属性都包含在频繁项集中
  174. if(itemIsSatisfied(item,attrValueList)){
  175. tempArray=item.getIdArray();
  176. classType=classificationBaseRules(tempArray);
  177. if(classType!=null){
  178. //作属性替换
  179. classType=num2Attr.get(Integer.parseInt(classType));
  180. break;
  181. }
  182. }
  183. }
  184. returnclassType;
  185. }
  186. /**
  187. *基于关联规则进行分类
  188. *
  189. *@paramitems
  190. *频繁项
  191. *@return
  192. */
  193. privateStringclassificationBaseRules(String[]items){
  194. StringclassType=null;
  195. String[]arrayTemp;
  196. intcount1=0;
  197. intcount2=0;
  198. //置信度
  199. doubleconfidenceRate;
  200. String[]noClassTypeItems=newString[items.length-1];
  201. for(inti=0,k=0;i<items.length;i++){
  202. if(!classTypes.contains(Integer.parseInt(items[i]))){
  203. noClassTypeItems[k]=items[i];
  204. k++;
  205. }else{
  206. classType=items[i];
  207. }
  208. }
  209. for(String[]array:totalDatas){
  210. //去除ID数字号
  211. arrayTemp=newString[array.length-1];
  212. System.arraycopy(array,1,arrayTemp,0,array.length-1);
  213. if(isStrArrayContain(arrayTemp,noClassTypeItems)){
  214. count1++;
  215. if(isStrArrayContain(arrayTemp,items)){
  216. count2++;
  217. }
  218. }
  219. }
  220. //做置信度的计算
  221. confidenceRate=count1*1.0/count2;
  222. if(confidenceRate>=minConf){
  223. returnclassType;
  224. }else{
  225. //如果不满足最小置信度要求,则此关联规则无效
  226. returnnull;
  227. }
  228. }
  229. /**
  230. *判断单个字符是否包含在字符数组中
  231. *
  232. *@paramarray
  233. *字符数组
  234. *@params
  235. *判断的单字符
  236. *@return
  237. */
  238. privatebooleanstrIsContained(String[]array,Strings){
  239. booleanisContained=false;
  240. for(Stringstr:array){
  241. if(str.equals(s)){
  242. isContained=true;
  243. break;
  244. }
  245. }
  246. returnisContained;
  247. }
  248. /**
  249. *数组array2是否包含于array1中,不需要完全一样
  250. *
  251. *@paramarray1
  252. *@paramarray2
  253. *@return
  254. */
  255. privatebooleanisStrArrayContain(String[]array1,String[]array2){
  256. booleanisContain=true;
  257. for(Strings2:array2){
  258. isContain=false;
  259. for(Strings1:array1){
  260. //只要s2字符存在于array1中,这个字符就算包含在array1中
  261. if(s2.equals(s1)){
  262. isContain=true;
  263. break;
  264. }
  265. }
  266. //一旦发现不包含的字符,则array2数组不包含于array1中
  267. if(!isContain){
  268. break;
  269. }
  270. }
  271. returnisContain;
  272. }
  273. /**
  274. *判断频繁项集是否满足查询
  275. *
  276. *@paramitem
  277. *待判断的频繁项集
  278. *@paramattrValues
  279. *查询的属性值列表
  280. *@return
  281. */
  282. privatebooleanitemIsSatisfied(FrequentItemitem,
  283. ArrayList<String>attrValues){
  284. booleanisContained=false;
  285. String[]array=item.getIdArray();
  286. for(Strings:attrValues){
  287. isContained=true;
  288. if(!strIsContained(array,s)){
  289. isContained=false;
  290. break;
  291. }
  292. if(!isContained){
  293. break;
  294. }
  295. }
  296. if(isContained){
  297. isContained=false;
  298. //还要验证是否频繁项集中是否包含分类属性
  299. for(Integertype:classTypes){
  300. if(strIsContained(array,type+"")){
  301. isContained=true;
  302. break;
  303. }
  304. }
  305. }
  306. returnisContained;
  307. }
  308. }
调用类Client.java:

  1. packageDataMining_CBA;
  2. importjava.text.MessageFormat;
  3. /**
  4. *CBA算法--基于关联规则的分类算法
  5. *@authorlyq
  6. *
  7. */
  8. publicclassClient{
  9. publicstaticvoidmain(String[]args){
  10. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  11. StringattrDesc="Age=Senior,CreditRating=Fair";
  12. Stringclassification=null;
  13. //最小支持度阈值率
  14. doubleminSupportRate=0.2;
  15. //最小置信度阈值
  16. doubleminConf=0.7;
  17. CBATooltool=newCBATool(filePath,minSupportRate,minConf);
  18. classification=tool.CBAJudge(attrDesc);
  19. System.out.println(MessageFormat.format("{0}的关联分类结果为{1}",attrDesc,classification));
  20. }
  21. }
代码的结果为:

  1. 频繁1项集:
  2. {1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
  3. 频繁2项集:
  4. {1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
  5. 频繁3项集:
  6. {1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
  7. 频繁4项集:
  8. 频繁5项集:
  9. 频繁6项集:
  10. 频繁7项集:
  11. 频繁8项集:
  12. 频繁9项集:
  13. 频繁10项集:
  14. 频繁11项集:
  15. Age=Senior,CreditRating=Fair的关联分类结果为CLassYes
上面的有些项集为空说明没有此项集。Apriori算法类可以在 这里进行查阅,这里只展示了CBA算法的部分。

算法的分析

我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。

程序实现时遇到的问题

在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。

我对CBA算法的理解

CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。

算法原理

CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:

1、输入数据记录,就是一条条的属性值。

2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。

3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。

4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。

这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 113HighNoFairCLassNo
  3. 211HighNoExcellentCLassNo
  4. 325HighNoFairCLassYes
  5. 445MediumNoFairCLassYes
  6. 550LowYesFairCLassYes
  7. 651LowYesExcellentCLassNo
  8. 730LowYesExcellentCLassYes
  9. 813MediumNoFairCLassNo
  10. 99LowYesFairCLassYes
  11. 1055MediumYesFairCLassYes
  12. 1114MediumYesExcellentCLassYes
  13. 1233MediumNoExcellentCLassYes
  14. 1333HighYesFairCLassYes
  15. 1441MediumNoExcellentCLassNo
属性值对应的数字替换图:

  1. Medium=5,CLassYes=12,Excellent=10,Low=6,Fair=9,CLassNo=11,Young=1,Middle_aged=2,Yes=8,No=7,High=4,Senior=3
体会之后的数据变为了下面的事务项:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 1147911
  3. 21471011
  4. 3247912
  5. 4357912
  6. 5368912
  7. 63681011
  8. 72681012
  9. 8157911
  10. 9168912
  11. 10358912
  12. 111581012
  13. 122571012
  14. 13248912
  15. 143571011
把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以 点击这里进行了解。

算法的代码实现

测试数据就是上面的内容。

CBATool.java:

  1. packageDataMining_CBA;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.HashMap;
  8. importjava.util.regex.Matcher;
  9. importjava.util.regex.Pattern;
  10. importDataMining_CBA.AprioriTool.AprioriTool;
  11. importDataMining_CBA.AprioriTool.FrequentItem;
  12. /**
  13. *CBA算法(关联规则分类)工具类
  14. *
  15. *@authorlyq
  16. *
  17. */
  18. publicclassCBATool{
  19. //年龄的类别划分
  20. publicfinalStringAGE="Age";
  21. publicfinalStringAGE_YOUNG="Young";
  22. publicfinalStringAGE_MIDDLE_AGED="Middle_aged";
  23. publicfinalStringAGE_Senior="Senior";
  24. //测试数据地址
  25. privateStringfilePath;
  26. //最小支持度阈值率
  27. privatedoubleminSupportRate;
  28. //最小置信度阈值,用来判断是否能够成为关联规则
  29. privatedoubleminConf;
  30. //最小支持度
  31. privateintminSupportCount;
  32. //属性列名称
  33. privateString[]attrNames;
  34. //类别属性所代表的数字集合
  35. privateArrayList<Integer>classTypes;
  36. //用二维数组保存测试数据
  37. privateArrayList<String[]>totalDatas;
  38. //Apriori算法工具类
  39. privateAprioriToolaprioriTool;
  40. //属性到数字的映射图
  41. privateHashMap<String,Integer>attr2Num;
  42. privateHashMap<Integer,String>num2Attr;
  43. publicCBATool(StringfilePath,doubleminSupportRate,doubleminConf){
  44. this.filePath=filePath;
  45. this.minConf=minConf;
  46. this.minSupportRate=minSupportRate;
  47. readDataFile();
  48. }
  49. /**
  50. *从文件中读取数据
  51. */
  52. privatevoidreadDataFile(){
  53. Filefile=newFile(filePath);
  54. ArrayList<String[]>dataArray=newArrayList<String[]>();
  55. try{
  56. BufferedReaderin=newBufferedReader(newFileReader(file));
  57. Stringstr;
  58. String[]tempArray;
  59. while((str=in.readLine())!=null){
  60. tempArray=str.split("");
  61. dataArray.add(tempArray);
  62. }
  63. in.close();
  64. }catch(IOExceptione){
  65. e.getStackTrace();
  66. }
  67. totalDatas=newArrayList<>();
  68. for(String[]array:dataArray){
  69. totalDatas.add(array);
  70. }
  71. attrNames=totalDatas.get(0);
  72. minSupportCount=(int)(minSupportRate*totalDatas.size());
  73. attributeReplace();
  74. }
  75. /**
  76. *属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
  77. */
  78. privatevoidattributeReplace(){
  79. intcurrentValue=1;
  80. intnum=0;
  81. Strings;
  82. //属性名到数字的映射图
  83. attr2Num=newHashMap<>();
  84. num2Attr=newHashMap<>();
  85. classTypes=newArrayList<>();
  86. //按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
  87. for(intj=1;j<attrNames.length;j++){
  88. for(inti=1;i<totalDatas.size();i++){
  89. s=totalDatas.get(i)[j];
  90. //如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
  91. if(attrNames[j].equals(AGE)){
  92. num=Integer.parseInt(s);
  93. if(num<=20&&num>0){
  94. totalDatas.get(i)[j]=AGE_YOUNG;
  95. }elseif(num>20&&num<=40){
  96. totalDatas.get(i)[j]=AGE_MIDDLE_AGED;
  97. }elseif(num>40){
  98. totalDatas.get(i)[j]=AGE_Senior;
  99. }
  100. }
  101. if(!attr2Num.containsKey(totalDatas.get(i)[j])){
  102. attr2Num.put(totalDatas.get(i)[j],currentValue);
  103. num2Attr.put(currentValue,totalDatas.get(i)[j]);
  104. if(j==attrNames.length-1){
  105. //如果是组后一列,说明是分类类别列,记录下来
  106. classTypes.add(currentValue);
  107. }
  108. currentValue++;
  109. }
  110. }
  111. }
  112. //对原始的数据作属性替换,每条记录变为类似于事务数据的形式
  113. for(inti=1;i<totalDatas.size();i++){
  114. for(intj=1;j<attrNames.length;j++){
  115. s=totalDatas.get(i)[j];
  116. if(attr2Num.containsKey(s)){
  117. totalDatas.get(i)[j]=attr2Num.get(s)+"";
  118. }
  119. }
  120. }
  121. }
  122. /**
  123. *Apriori计算全部频繁项集
  124. *@return
  125. */
  126. privateArrayList<FrequentItem>aprioriCalculate(){
  127. String[]tempArray;
  128. ArrayList<FrequentItem>totalFrequentItems;
  129. ArrayList<String[]>copyData=(ArrayList<String[]>)totalDatas.clone();
  130. //去除属性名称行
  131. copyData.remove(0);
  132. //去除首列ID
  133. for(inti=0;i<copyData.size();i++){
  134. String[]array=copyData.get(i);
  135. tempArray=newString[array.length-1];
  136. System.arraycopy(array,1,tempArray,0,tempArray.length);
  137. copyData.set(i,tempArray);
  138. }
  139. aprioriTool=newAprioriTool(copyData,minSupportCount);
  140. aprioriTool.computeLink();
  141. totalFrequentItems=aprioriTool.getTotalFrequentItems();
  142. returntotalFrequentItems;
  143. }
  144. /**
  145. *基于关联规则的分类
  146. *
  147. *@paramattrValues
  148. *预先知道的一些属性
  149. *@return
  150. */
  151. publicStringCBAJudge(StringattrValues){
  152. intvalue=0;
  153. //最终分类类别
  154. StringclassType=null;
  155. String[]tempArray;
  156. //已知的属性值
  157. ArrayList<String>attrValueList=newArrayList<>();
  158. ArrayList<FrequentItem>totalFrequentItems;
  159. totalFrequentItems=aprioriCalculate();
  160. //将查询条件进行逐一属性的分割
  161. String[]array=attrValues.split(",");
  162. for(Stringrecord:array){
  163. tempArray=record.split("=");
  164. value=attr2Num.get(tempArray[1]);
  165. attrValueList.add(value+"");
  166. }
  167. //在频繁项集中寻找符合条件的项
  168. for(FrequentItemitem:totalFrequentItems){
  169. //过滤掉不满足个数频繁项
  170. if(item.getIdArray().length<(attrValueList.size()+1)){
  171. continue;
  172. }
  173. //要保证查询的属性都包含在频繁项集中
  174. if(itemIsSatisfied(item,attrValueList)){
  175. tempArray=item.getIdArray();
  176. classType=classificationBaseRules(tempArray);
  177. if(classType!=null){
  178. //作属性替换
  179. classType=num2Attr.get(Integer.parseInt(classType));
  180. break;
  181. }
  182. }
  183. }
  184. returnclassType;
  185. }
  186. /**
  187. *基于关联规则进行分类
  188. *
  189. *@paramitems
  190. *频繁项
  191. *@return
  192. */
  193. privateStringclassificationBaseRules(String[]items){
  194. StringclassType=null;
  195. String[]arrayTemp;
  196. intcount1=0;
  197. intcount2=0;
  198. //置信度
  199. doubleconfidenceRate;
  200. String[]noClassTypeItems=newString[items.length-1];
  201. for(inti=0,k=0;i<items.length;i++){
  202. if(!classTypes.contains(Integer.parseInt(items[i]))){
  203. noClassTypeItems[k]=items[i];
  204. k++;
  205. }else{
  206. classType=items[i];
  207. }
  208. }
  209. for(String[]array:totalDatas){
  210. //去除ID数字号
  211. arrayTemp=newString[array.length-1];
  212. System.arraycopy(array,1,arrayTemp,0,array.length-1);
  213. if(isStrArrayContain(arrayTemp,noClassTypeItems)){
  214. count1++;
  215. if(isStrArrayContain(arrayTemp,items)){
  216. count2++;
  217. }
  218. }
  219. }
  220. //做置信度的计算
  221. confidenceRate=count1*1.0/count2;
  222. if(confidenceRate>=minConf){
  223. returnclassType;
  224. }else{
  225. //如果不满足最小置信度要求,则此关联规则无效
  226. returnnull;
  227. }
  228. }
  229. /**
  230. *判断单个字符是否包含在字符数组中
  231. *
  232. *@paramarray
  233. *字符数组
  234. *@params
  235. *判断的单字符
  236. *@return
  237. */
  238. privatebooleanstrIsContained(String[]array,Strings){
  239. booleanisContained=false;
  240. for(Stringstr:array){
  241. if(str.equals(s)){
  242. isContained=true;
  243. break;
  244. }
  245. }
  246. returnisContained;
  247. }
  248. /**
  249. *数组array2是否包含于array1中,不需要完全一样
  250. *
  251. *@paramarray1
  252. *@paramarray2
  253. *@return
  254. */
  255. privatebooleanisStrArrayContain(String[]array1,String[]array2){
  256. booleanisContain=true;
  257. for(Strings2:array2){
  258. isContain=false;
  259. for(Strings1:array1){
  260. //只要s2字符存在于array1中,这个字符就算包含在array1中
  261. if(s2.equals(s1)){
  262. isContain=true;
  263. break;
  264. }
  265. }
  266. //一旦发现不包含的字符,则array2数组不包含于array1中
  267. if(!isContain){
  268. break;
  269. }
  270. }
  271. returnisContain;
  272. }
  273. /**
  274. *判断频繁项集是否满足查询
  275. *
  276. *@paramitem
  277. *待判断的频繁项集
  278. *@paramattrValues
  279. *查询的属性值列表
  280. *@return
  281. */
  282. privatebooleanitemIsSatisfied(FrequentItemitem,
  283. ArrayList<String>attrValues){
  284. booleanisContained=false;
  285. String[]array=item.getIdArray();
  286. for(Strings:attrValues){
  287. isContained=true;
  288. if(!strIsContained(array,s)){
  289. isContained=false;
  290. break;
  291. }
  292. if(!isContained){
  293. break;
  294. }
  295. }
  296. if(isContained){
  297. isContained=false;
  298. //还要验证是否频繁项集中是否包含分类属性
  299. for(Integertype:classTypes){
  300. if(strIsContained(array,type+"")){
  301. isContained=true;
  302. break;
  303. }
  304. }
  305. }
  306. returnisContained;
  307. }
  308. }
调用类Client.java:

  1. packageDataMining_CBA;
  2. importjava.text.MessageFormat;
  3. /**
  4. *CBA算法--基于关联规则的分类算法
  5. *@authorlyq
  6. *
  7. */
  8. publicclassClient{
  9. publicstaticvoidmain(String[]args){
  10. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  11. StringattrDesc="Age=Senior,CreditRating=Fair";
  12. Stringclassification=null;
  13. //最小支持度阈值率
  14. doubleminSupportRate=0.2;
  15. //最小置信度阈值
  16. doubleminConf=0.7;
  17. CBATooltool=newCBATool(filePath,minSupportRate,minConf);
  18. classification=tool.CBAJudge(attrDesc);
  19. System.out.println(MessageFormat.format("{0}的关联分类结果为{1}",attrDesc,classification));
  20. }
  21. }
代码的结果为:

  1. 频繁1项集:
  2. {1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
  3. 频繁2项集:
  4. {1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
  5. 频繁3项集:
  6. {1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
  7. 频繁4项集:
  8. 频繁5项集:
  9. 频繁6项集:
  10. 频繁7项集:
  11. 频繁8项集:
  12. 频繁9项集:
  13. 频繁10项集:
  14. 频繁11项集:
  15. Age=Senior,CreditRating=Fair的关联分类结果为CLassYes
上面的有些项集为空说明没有此项集。Apriori算法类可以在 这里进行查阅,这里只展示了CBA算法的部分。

算法的分析

我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。

程序实现时遇到的问题

在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。

我对CBA算法的理解

CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。

算法原理

CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:

1、输入数据记录,就是一条条的属性值。

2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。

3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。

4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。

这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 113HighNoFairCLassNo
  3. 211HighNoExcellentCLassNo
  4. 325HighNoFairCLassYes
  5. 445MediumNoFairCLassYes
  6. 550LowYesFairCLassYes
  7. 651LowYesExcellentCLassNo
  8. 730LowYesExcellentCLassYes
  9. 813MediumNoFairCLassNo
  10. 99LowYesFairCLassYes
  11. 1055MediumYesFairCLassYes
  12. 1114MediumYesExcellentCLassYes
  13. 1233MediumNoExcellentCLassYes
  14. 1333HighYesFairCLassYes
  15. 1441MediumNoExcellentCLassNo
属性值对应的数字替换图:

  1. Medium=5,CLassYes=12,Excellent=10,Low=6,Fair=9,CLassNo=11,Young=1,Middle_aged=2,Yes=8,No=7,High=4,Senior=3
体会之后的数据变为了下面的事务项:

  1. RidAgeIncomeStudentCreditRatingBuysComputer
  2. 1147911
  3. 21471011
  4. 3247912
  5. 4357912
  6. 5368912
  7. 63681011
  8. 72681012
  9. 8157911
  10. 9168912
  11. 10358912
  12. 111581012
  13. 122571012
  14. 13248912
  15. 143571011
把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以 点击这里进行了解。

算法的代码实现

测试数据就是上面的内容。

CBATool.java:

  1. packageDataMining_CBA;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.HashMap;
  8. importjava.util.regex.Matcher;
  9. importjava.util.regex.Pattern;
  10. importDataMining_CBA.AprioriTool.AprioriTool;
  11. importDataMining_CBA.AprioriTool.FrequentItem;
  12. /**
  13. *CBA算法(关联规则分类)工具类
  14. *
  15. *@authorlyq
  16. *
  17. */
  18. publicclassCBATool{
  19. //年龄的类别划分
  20. publicfinalStringAGE="Age";
  21. publicfinalStringAGE_YOUNG="Young";
  22. publicfinalStringAGE_MIDDLE_AGED="Middle_aged";
  23. publicfinalStringAGE_Senior="Senior";
  24. //测试数据地址
  25. privateStringfilePath;
  26. //最小支持度阈值率
  27. privatedoubleminSupportRate;
  28. //最小置信度阈值,用来判断是否能够成为关联规则
  29. privatedoubleminConf;
  30. //最小支持度
  31. privateintminSupportCount;
  32. //属性列名称
  33. privateString[]attrNames;
  34. //类别属性所代表的数字集合
  35. privateArrayList<Integer>classTypes;
  36. //用二维数组保存测试数据
  37. privateArrayList<String[]>totalDatas;
  38. //Apriori算法工具类
  39. privateAprioriToolaprioriTool;
  40. //属性到数字的映射图
  41. privateHashMap<String,Integer>attr2Num;
  42. privateHashMap<Integer,String>num2Attr;
  43. publicCBATool(StringfilePath,doubleminSupportRate,doubleminConf){
  44. this.filePath=filePath;
  45. this.minConf=minConf;
  46. this.minSupportRate=minSupportRate;
  47. readDataFile();
  48. }
  49. /**
  50. *从文件中读取数据
  51. */
  52. privatevoidreadDataFile(){
  53. Filefile=newFile(filePath);
  54. ArrayList<String[]>dataArray=newArrayList<String[]>();
  55. try{
  56. BufferedReaderin=newBufferedReader(newFileReader(file));
  57. Stringstr;
  58. String[]tempArray;
  59. while((str=in.readLine())!=null){
  60. tempArray=str.split("");
  61. dataArray.add(tempArray);
  62. }
  63. in.close();
  64. }catch(IOExceptione){
  65. e.getStackTrace();
  66. }
  67. totalDatas=newArrayList<>();
  68. for(String[]array:dataArray){
  69. totalDatas.add(array);
  70. }
  71. attrNames=totalDatas.get(0);
  72. minSupportCount=(int)(minSupportRate*totalDatas.size());
  73. attributeReplace();
  74. }
  75. /**
  76. *属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
  77. */
  78. privatevoidattributeReplace(){
  79. intcurrentValue=1;
  80. intnum=0;
  81. Strings;
  82. //属性名到数字的映射图
  83. attr2Num=newHashMap<>();
  84. num2Attr=newHashMap<>();
  85. classTypes=newArrayList<>();
  86. //按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
  87. for(intj=1;j<attrNames.length;j++){
  88. for(inti=1;i<totalDatas.size();i++){
  89. s=totalDatas.get(i)[j];
  90. //如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
  91. if(attrNames[j].equals(AGE)){
  92. num=Integer.parseInt(s);
  93. if(num<=20&&num>0){
  94. totalDatas.get(i)[j]=AGE_YOUNG;
  95. }elseif(num>20&&num<=40){
  96. totalDatas.get(i)[j]=AGE_MIDDLE_AGED;
  97. }elseif(num>40){
  98. totalDatas.get(i)[j]=AGE_Senior;
  99. }
  100. }
  101. if(!attr2Num.containsKey(totalDatas.get(i)[j])){
  102. attr2Num.put(totalDatas.get(i)[j],currentValue);
  103. num2Attr.put(currentValue,totalDatas.get(i)[j]);
  104. if(j==attrNames.length-1){
  105. //如果是组后一列,说明是分类类别列,记录下来
  106. classTypes.add(currentValue);
  107. }
  108. currentValue++;
  109. }
  110. }
  111. }
  112. //对原始的数据作属性替换,每条记录变为类似于事务数据的形式
  113. for(inti=1;i<totalDatas.size();i++){
  114. for(intj=1;j<attrNames.length;j++){
  115. s=totalDatas.get(i)[j];
  116. if(attr2Num.containsKey(s)){
  117. totalDatas.get(i)[j]=attr2Num.get(s)+"";
  118. }
  119. }
  120. }
  121. }
  122. /**
  123. *Apriori计算全部频繁项集
  124. *@return
  125. */
  126. privateArrayList<FrequentItem>aprioriCalculate(){
  127. String[]tempArray;
  128. ArrayList<FrequentItem>totalFrequentItems;
  129. ArrayList<String[]>copyData=(ArrayList<String[]>)totalDatas.clone();
  130. //去除属性名称行
  131. copyData.remove(0);
  132. //去除首列ID
  133. for(inti=0;i<copyData.size();i++){
  134. String[]array=copyData.get(i);
  135. tempArray=newString[array.length-1];
  136. System.arraycopy(array,1,tempArray,0,tempArray.length);
  137. copyData.set(i,tempArray);
  138. }
  139. aprioriTool=newAprioriTool(copyData,minSupportCount);
  140. aprioriTool.computeLink();
  141. totalFrequentItems=aprioriTool.getTotalFrequentItems();
  142. returntotalFrequentItems;
  143. }
  144. /**
  145. *基于关联规则的分类
  146. *
  147. *@paramattrValues
  148. *预先知道的一些属性
  149. *@return
  150. */
  151. publicStringCBAJudge(StringattrValues){
  152. intvalue=0;
  153. //最终分类类别
  154. StringclassType=null;
  155. String[]tempArray;
  156. //已知的属性值
  157. ArrayList<String>attrValueList=newArrayList<>();
  158. ArrayList<FrequentItem>totalFrequentItems;
  159. totalFrequentItems=aprioriCalculate();
  160. //将查询条件进行逐一属性的分割
  161. String[]array=attrValues.split(",");
  162. for(Stringrecord:array){
  163. tempArray=record.split("=");
  164. value=attr2Num.get(tempArray[1]);
  165. attrValueList.add(value+"");
  166. }
  167. //在频繁项集中寻找符合条件的项
  168. for(FrequentItemitem:totalFrequentItems){
  169. //过滤掉不满足个数频繁项
  170. if(item.getIdArray().length<(attrValueList.size()+1)){
  171. continue;
  172. }
  173. //要保证查询的属性都包含在频繁项集中
  174. if(itemIsSatisfied(item,attrValueList)){
  175. tempArray=item.getIdArray();
  176. classType=classificationBaseRules(tempArray);
  177. if(classType!=null){
  178. //作属性替换
  179. classType=num2Attr.get(Integer.parseInt(classType));
  180. break;
  181. }
  182. }
  183. }
  184. returnclassType;
  185. }
  186. /**
  187. *基于关联规则进行分类
  188. *
  189. *@paramitems
  190. *频繁项
  191. *@return
  192. */
  193. privateStringclassificationBaseRules(String[]items){
  194. StringclassType=null;
  195. String[]arrayTemp;
  196. intcount1=0;
  197. intcount2=0;
  198. //置信度
  199. doubleconfidenceRate;
  200. String[]noClassTypeItems=newString[items.length-1];
  201. for(inti=0,k=0;i<items.length;i++){
  202. if(!classTypes.contains(Integer.parseInt(items[i]))){
  203. noClassTypeItems[k]=items[i];
  204. k++;
  205. }else{
  206. classType=items[i];
  207. }
  208. }
  209. for(String[]array:totalDatas){
  210. //去除ID数字号
  211. arrayTemp=newString[array.length-1];
  212. System.arraycopy(array,1,arrayTemp,0,array.length-1);
  213. if(isStrArrayContain(arrayTemp,noClassTypeItems)){
  214. count1++;
  215. if(isStrArrayContain(arrayTemp,items)){
  216. count2++;
  217. }
  218. }
  219. }
  220. //做置信度的计算
  221. confidenceRate=count1*1.0/count2;
  222. if(confidenceRate>=minConf){
  223. returnclassType;
  224. }else{
  225. //如果不满足最小置信度要求,则此关联规则无效
  226. returnnull;
  227. }
  228. }
  229. /**
  230. *判断单个字符是否包含在字符数组中
  231. *
  232. *@paramarray
  233. *字符数组
  234. *@params
  235. *判断的单字符
  236. *@return
  237. */
  238. privatebooleanstrIsContained(String[]array,Strings){
  239. booleanisContained=false;
  240. for(Stringstr:array){
  241. if(str.equals(s)){
  242. isContained=true;
  243. break;
  244. }
  245. }
  246. returnisContained;
  247. }
  248. /**
  249. *数组array2是否包含于array1中,不需要完全一样
  250. *
  251. *@paramarray1
  252. *@paramarray2
  253. *@return
  254. */
  255. privatebooleanisStrArrayContain(String[]array1,String[]array2){
  256. booleanisContain=true;
  257. for(Strings2:array2){
  258. isContain=false;
  259. for(Strings1:array1){
  260. //只要s2字符存在于array1中,这个字符就算包含在array1中
  261. if(s2.equals(s1)){
  262. isContain=true;
  263. break;
  264. }
  265. }
  266. //一旦发现不包含的字符,则array2数组不包含于array1中
  267. if(!isContain){
  268. break;
  269. }
  270. }
  271. returnisContain;
  272. }
  273. /**
  274. *判断频繁项集是否满足查询
  275. *
  276. *@paramitem
  277. *待判断的频繁项集
  278. *@paramattrValues
  279. *查询的属性值列表
  280. *@return
  281. */
  282. privatebooleanitemIsSatisfied(FrequentItemitem,
  283. ArrayList<String>attrValues){
  284. booleanisContained=false;
  285. String[]array=item.getIdArray();
  286. for(Strings:attrValues){
  287. isContained=true;
  288. if(!strIsContained(array,s)){
  289. isContained=false;
  290. break;
  291. }
  292. if(!isContained){
  293. break;
  294. }
  295. }
  296. if(isContained){
  297. isContained=false;
  298. //还要验证是否频繁项集中是否包含分类属性
  299. for(Integertype:classTypes){
  300. if(strIsContained(array,type+"")){
  301. isContained=true;
  302. break;
  303. }
  304. }
  305. }
  306. returnisContained;
  307. }
  308. }
调用类Client.java:

  1. packageDataMining_CBA;
  2. importjava.text.MessageFormat;
  3. /**
  4. *CBA算法--基于关联规则的分类算法
  5. *@authorlyq
  6. *
  7. */
  8. publicclassClient{
  9. publicstaticvoidmain(String[]args){
  10. StringfilePath="C:\\Users\\lyq\\Desktop\\icon\\input.txt";
  11. StringattrDesc="Age=Senior,CreditRating=Fair";
  12. Stringclassification=null;
  13. //最小支持度阈值率
  14. doubleminSupportRate=0.2;
  15. //最小置信度阈值
  16. doubleminConf=0.7;
  17. CBATooltool=newCBATool(filePath,minSupportRate,minConf);
  18. classification=tool.CBAJudge(attrDesc);
  19. System.out.println(MessageFormat.format("{0}的关联分类结果为{1}",attrDesc,classification));
  20. }
  21. }
代码的结果为:

  1. 频繁1项集:
  2. {1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
  3. 频繁2项集:
  4. {1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
  5. 频繁3项集:
  6. {1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
  7. 频繁4项集:
  8. 频繁5项集:
  9. 频繁6项集:
  10. 频繁7项集:
  11. 频繁8项集:
  12. 频繁9项集:
  13. 频繁10项集:
  14. 频繁11项集:
  15. Age=Senior,CreditRating=Fair的关联分类结果为CLassYes
上面的有些项集为空说明没有此项集。Apriori算法类可以在 这里进行查阅,这里只展示了CBA算法的部分。

算法的分析

我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。

程序实现时遇到的问题

在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。

我对CBA算法的理解

CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值