刚开始接触Java，再加上学了《数据挖掘》，就用Java实现了Apriori关联规则算法，花了我今天一天的时间，眼睛都看酸了。
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
/**
 * Apriori frequent-itemset and association-rule mining over an in-memory
 * transaction table.
 *
 * <p>A data set is a {@code String[][]}: one row per transaction, one cell per
 * item name. Support is always measured as the number of transactions that
 * contain every item of an itemset.
 */
public class Apriori {

    /**
     * Builds a candidate set in which every candidate carries its support count.
     *
     * @param FrequentItemSet frequent itemsets of the previous round; every entry
     *                        is expected to have the same number of items. Ignored
     *                        when {@code isFirst} is true.
     * @param DataSet         transactions to count support against
     * @param isFirst         true to build the single-item candidates straight from
     *                        the data set, false to join pairs of
     *                        {@code FrequentItemSet}
     * @return a newly allocated list of candidates; no threshold is applied here
     */
    public ArrayList<Item> CreateCandidateSet(ArrayList<Item> FrequentItemSet, String[][] DataSet, boolean isFirst) {
        if (isFirst) {
            return singleItemCandidates(DataSet);
        }
        return joinedCandidates(FrequentItemSet, DataSet);
    }

    /** Collects every distinct item of the data set with the number of transactions containing it. */
    private ArrayList<Item> singleItemCandidates(String[][] DataSet) {
        ArrayList<Item> candidates = new ArrayList<Item>();
        for (String[] transaction : DataSet) {
            for (int j = 0; j < transaction.length; j++) {
                // Support is counted per transaction (as CalSupCnt does), so an item
                // that is repeated inside one transaction must only count once.
                if (occursBefore(transaction, j)) {
                    continue;
                }
                Item probe = new Item(transaction[j], 1);
                int position = candidates.indexOf(probe);
                if (position >= 0) {
                    candidates.get(position).SupportCount++;
                } else {
                    candidates.add(probe);
                }
            }
        }
        return candidates;
    }

    /** Tells whether the item at {@code index} already appeared earlier in the same transaction. */
    private static boolean occursBefore(String[] transaction, int index) {
        for (int k = 0; k < index; k++) {
            if (transaction[k].equals(transaction[index])) {
                return true;
            }
        }
        return false;
    }

    /** Joins every pair of k-itemsets whose union has exactly k+1 items and counts its support. */
    private ArrayList<Item> joinedCandidates(ArrayList<Item> FrequentItemSet, String[][] DataSet) {
        ArrayList<Item> candidates = new ArrayList<Item>();
        int count = FrequentItemSet.size();
        if (count < 2) {
            return candidates; // no pair to join
        }
        int targetSize = FrequentItemSet.get(0).Element.length + 1;
        for (int i = 0; i < count - 1; i++) {
            for (int j = i + 1; j < count; j++) {
                ArrayList<String> merged = Union(StrToList(FrequentItemSet.get(i).Element),
                        StrToList(FrequentItemSet.get(j).Element));
                if (merged.size() != targetSize) {
                    continue; // the two itemsets differ in more than one item
                }
                String[] itemset = merged.toArray(new String[merged.size()]);
                // Different pairs can produce the same union; keep it only once.
                if (!candidates.contains(new Item(itemset))) {
                    candidates.add(new Item(itemset, CalSupCnt(DataSet, itemset)));
                }
            }
        }
        return candidates;
    }

    /**
     * Drops every candidate whose support count is below the threshold.
     *
     * <p>The filtering is done in place: {@code CandidateSet} itself is modified
     * and returned.
     *
     * @param CandidateSet candidates with their support counts already filled in
     * @param MinSupCnt    minimum support count an itemset needs to be kept
     * @return the same list instance, now holding only the frequent itemsets
     */
    public ArrayList<Item> CreateFrequentItemSet(ArrayList<Item> CandidateSet, int MinSupCnt) {
        // Walk backwards so removing an entry never shifts one that is still to be visited.
        for (int i = CandidateSet.size() - 1; i >= 0; i--) {
            if (CandidateSet.get(i).SupportCount < MinSupCnt) {
                CandidateSet.remove(i);
            }
        }
        return CandidateSet;
    }

    /**
     * Computes the support count of an itemset.
     *
     * @param DataSet transactions to scan
     * @param Str     items that must all be present in a transaction
     * @return the number of transactions containing every item of {@code Str};
     *         an empty {@code Str} is contained in every transaction
     */
    public int CalSupCnt(String[][] DataSet, String[] Str) {
        int supportCount = 0;
        for (String[] transaction : DataSet) {
            if (containsAll(transaction, Str)) {
                supportCount++;
            }
        }
        return supportCount;
    }

    /** Tells whether every entry of {@code items} occurs somewhere in {@code transaction}. */
    private static boolean containsAll(String[] transaction, String[] items) {
        for (String wanted : items) {
            boolean found = false;
            for (String present : transaction) {
                if (present.equals(wanted)) {
                    found = true;
                    break;
                }
            }
            if (!found) {
                return false;
            }
        }
        return true;
    }

    /**
     * Copies an array of strings into a new, mutable list.
     *
     * @param DataSet the strings to copy
     * @return a new list holding the array's elements in the same order
     */
    public ArrayList<String> StrToList(String[] DataSet) {
        ArrayList<String> list = new ArrayList<String>(DataSet.length);
        Collections.addAll(list, DataSet);
        return list;
    }

    /**
     * Computes the union of two lists.
     *
     * <p>Note that {@code list1} is modified and returned: entries also present in
     * {@code list2} are removed from it, then all of {@code list2} is appended.
     *
     * @param list1 first operand; receives the result
     * @param list2 second operand; left untouched
     * @return {@code list1}, now holding the union
     */
    public ArrayList<String> Union(ArrayList<String> list1, ArrayList<String> list2) {
        // removeAll + addAll yields the union; retainAll would give the intersection instead.
        list1.removeAll(list2);
        list1.addAll(list2);
        return list1;
    }

    /**
     * Prints each itemset (items sorted) with its support count to standard output.
     *
     * @param FrequentItemSet itemsets to print
     */
    public void PrintFrequentItemSet(ArrayList<Item> FrequentItemSet) {
        System.out.println("最大频繁项集:");
        for (Item itemset : FrequentItemSet) {
            // Sort a copy so the itemset's own element order stays as it was.
            ArrayList<String> names = StrToList(itemset.Element);
            Collections.sort(names);
            System.out.println(names.toString() + "\tSupportCount = " + itemset.SupportCount);
        }
    }

    /**
     * Prints the association rules, one per line, to standard output.
     *
     * @param AssociationRules rule strings to print
     */
    public void PrintAssociationRules(String[] AssociationRules) {
        System.out.println("\n强关联规则:");
        for (String rule : AssociationRules) {
            System.out.println(rule);
        }
    }

    /**
     * Finds the frequent itemsets of the largest size that still reaches the
     * minimum support.
     *
     * <p>Starting from single items, candidates are generated and filtered round
     * after round; the result is the last round that produced a non-empty set.
     *
     * @param DataSet   transactions to mine
     * @param MinSupCnt minimum support count
     * @return the largest-size frequent itemsets, or an empty list when not even a
     *         single item is frequent
     */
    public ArrayList<Item> GetMaxFrequentItemSet(String[][] DataSet, int MinSupCnt) {
        ArrayList<Item> largest = new ArrayList<Item>();
        ArrayList<Item> current = CreateFrequentItemSet(CreateCandidateSet(largest, DataSet, true), MinSupCnt);
        while (!current.isEmpty()) {
            // CreateCandidateSet always returns a fresh list, so the previous round's
            // result can be kept by reference; no defensive clone is needed.
            largest = current;
            current = CreateFrequentItemSet(CreateCandidateSet(largest, DataSet, false), MinSupCnt);
        }
        return largest;
    }

    /**
     * Derives association rules from frequent itemsets.
     *
     * <p>For every itemset, each non-empty proper subset A is tried as antecedent
     * with the remaining items B as consequent. The rule {@code A ==> B} is kept
     * when {@code support(itemset) / support(A) >= MinConf}. A split whose
     * antecedent occurs in no transaction has no defined confidence and is skipped.
     *
     * @param MaxFrequentItemSet itemsets (with support counts) to split into rules
     * @param DataSet            transactions used to count antecedent support
     * @param MinConf            minimum confidence as a fraction, e.g. 0.3 for 30%
     * @return one formatted string per accepted rule, with the confidence shown as
     *         a whole-number percentage
     * @throws IllegalArgumentException if an itemset has more than 62 items, since
     *                                  its subsets cannot be enumerated with a
     *                                  64-bit mask
     */
    public String[] CreateAssociationRules(ArrayList<Item> MaxFrequentItemSet, String[][] DataSet, double MinConf) {
        ArrayList<String> rules = new ArrayList<String>();
        DecimalFormat percentFormat = new DecimalFormat("#");
        for (Item itemset : MaxFrequentItemSet) {
            String[] elements = itemset.Element;
            if (elements.length > 62) {
                throw new IllegalArgumentException(
                        "Itemset too large to enumerate its subsets: " + elements.length + " items");
            }
            // Bit k of the mask decides whether elements[k] goes to the antecedent.
            // Masks 0 and fullMask are excluded: they are the empty set and the whole set.
            long fullMask = (1L << elements.length) - 1;
            for (long mask = 1; mask < fullMask; mask++) {
                ArrayList<String> antecedent = new ArrayList<String>();
                ArrayList<String> consequent = new ArrayList<String>();
                for (int k = 0; k < elements.length; k++) {
                    if (((mask >> k) & 1L) == 1L) {
                        antecedent.add(elements[k]);
                    } else {
                        consequent.add(elements[k]);
                    }
                }
                int antecedentSupport = CalSupCnt(DataSet, antecedent.toArray(new String[antecedent.size()]));
                if (antecedentSupport == 0) {
                    continue; // avoid dividing by zero; confidence is undefined
                }
                double confidence = (double) itemset.SupportCount / antecedentSupport;
                if (confidence >= MinConf) {
                    Collections.sort(antecedent);
                    Collections.sort(consequent);
                    rules.add(antecedent.toString() + " ==> " + consequent.toString()
                            + "\tConfidence = " + percentFormat.format(100 * confidence) + "%");
                }
            }
        }
        return rules.toArray(new String[rules.size()]);
    }

    /**
     * Demo: mines a small hard-coded data set and prints the result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Sample transactions.
        String[][] DataSet = {
            {"I1", "I2", "I5"},
            {"I2", "I4"},
            {"I2", "I3"},
            {"I1", "I2", "I4"},
            {"I1", "I3"},
            {"I2", "I3"},
            {"I1", "I3"},
            {"I1", "I2", "I3", "I5"},
            {"I1", "I2", "I3"}};
        int MinSupCnt = 2;    // minimum support count
        double MinConf = 0.3; // minimum confidence
        Apriori apriori = new Apriori();
        ArrayList<Item> MaxFrequentItemSet = apriori.GetMaxFrequentItemSet(DataSet, MinSupCnt);
        apriori.PrintFrequentItemSet(MaxFrequentItemSet);
        String[] AssociationRules = apriori.CreateAssociationRules(MaxFrequentItemSet, DataSet, MinConf);
        apriori.PrintAssociationRules(AssociationRules);
    }
}
import java.util.ArrayList;
import java.util.Collections;
/**
 * An itemset (one or more item names) together with a support count.
 *
 * <p>Two items are equal when they hold the same item names, regardless of
 * order; the support count does not take part in equality. Both fields are
 * public and mutable, so an instance's hash code changes if {@code Element} is
 * modified afterwards.
 */
class Item {
    /** Item names making up this itemset; their order is irrelevant for equality. */
    public String[] Element;
    /** Support count attached to this itemset by whoever created it. */
    public int SupportCount;

    /**
     * Creates an itemset with a known support count.
     *
     * @param Element      item names; the array is copied, so later changes to the
     *                     caller's array do not affect this item
     * @param SupportCount support count to attach
     */
    public Item(String[] Element, int SupportCount) {
        this.Element = Element.clone();
        this.SupportCount = SupportCount;
    }

    /**
     * Creates an itemset with a support count of 0, e.g. as a lookup key.
     *
     * @param Element item names; the array is copied
     */
    public Item(String[] Element) {
        this(Element, 0);
    }

    /**
     * Creates a single-item itemset.
     *
     * @param Element      the only item name
     * @param SupportCount support count to attach
     */
    public Item(String Element, int SupportCount) {
        this.Element = new String[] { Element };
        this.SupportCount = SupportCount;
    }

    /**
     * Compares itemsets by content, ignoring element order and support count.
     * Used by {@code ArrayList.contains} and {@code ArrayList.indexOf}.
     *
     * @param obj object to compare with
     * @return true only when {@code obj} is an {@code Item} holding exactly the
     *         same item names
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Item)) {
            return false; // also covers null
        }
        Item other = (Item) obj;
        // Itemsets of different size can never be equal; checking this first keeps
        // the comparison symmetric for single-item operands.
        if (this.Element.length != other.Element.length) {
            return false;
        }
        return sortedCopy(this.Element).equals(sortedCopy(other.Element));
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: derived from the sorted
     * item names, so element order does not matter.
     */
    @Override
    public int hashCode() {
        return sortedCopy(this.Element).hashCode();
    }

    /** Returns the names as a new sorted list, leaving the array untouched. */
    private static ArrayList<String> sortedCopy(String[] names) {
        ArrayList<String> sorted = new ArrayList<String>(names.length);
        Collections.addAll(sorted, names);
        Collections.sort(sorted);
        return sorted;
    }
}
运行结果如下图所示：