算法细节见论文:Fast Algorithm for Mining Association Rules
图形化版本工程+测试用例下载戳这http://download.csdn.net/detail/michealtx/4266155
控制台版本C++代码如下:
#include <iostream>
#include <sstream>
#include <fstream>
#include <vector>
#include <set>
#include <map>
#include <ctime>
using namespace std;
//读取文件获取整个数据库存储在database中,fileName必须为char*型,要是用string会报错,in()不认
bool ObtainDatabase(vector<set<int> > &database,char *fileName)
{
/* set<int> data;
data.insert(1);data.insert(2);data.insert(5);
database.push_back(data);
data.clear();
data.insert(2);data.insert(4);
database.push_back(data);
data.clear();
data.insert(2);data.insert(3);
database.push_back(data);
data.clear();
data.insert(1);data.insert(2);data.insert(4);
database.push_back(data);
data.clear();
data.insert(1);data.insert(3);
database.push_back(data);
data.clear();
data.insert(2);data.insert(3);
database.push_back(data);
data.clear();
data.insert(1);data.insert(3);
database.push_back(data);
data.clear();
data.insert(1);data.insert(2);data.insert(3);data.insert(5);
database.push_back(data);
data.clear();
data.insert(1);data.insert(2);data.insert(3);
database.push_back(data);
*/
ifstream in(fileName);
if(!in)
{
cout<<"文件打开失败!"<<endl;
return false;
}
string s="";
unsigned int i=0;
while(getline(in,s))
{//读取一行记录
i++;
set<int> transaction;
int len=s.length();
string str="";
for(int i=0;i<len;i++)
{//将记录中的数提取出来
if(s[i]!=' ')
{
str+=s[i];
}
else if(s[i]==' '||i==len-1)
{
//字符串转int
stringstream stoi(str);
int item=0;
stoi>>item;
transaction.insert(item);
str="";
}
}
database.push_back(transaction);
s="";
}
cout<<i<<endl; //system("pause");
return true;
}
//遍历一遍数据库,创建1-项大项集
void CreateItemset(vector<set<int> >&database,vector<set<int> > &largeItemset,unsigned int minSupport,map<set<int>,int> &lm1)
{
map<int,int> dir;
map<int,int>::iterator dirIt;
vector<set<int> >::iterator databaseIt;
set<int> temp;
set<int>::iterator tempIt;
//根据数据库创建字典,字典形式为<item,count>
for(databaseIt=database.begin();databaseIt!=database.end();databaseIt++)
{
temp=*databaseIt;
for(tempIt=temp.begin();tempIt!=temp.end();tempIt++)
{
int item=*tempIt;
dirIt=dir.find(item);
if(dirIt==dir.end())
{//item不在字典dir中
dir.insert(pair<int,int>(item,1));
}
else
{//item在字典dir中,则将其count值加1
(dirIt->second)++;
}
}
}
//从字典中选出支持度超过minSopport的item
for(dirIt=dir.begin();dirIt!=dir.end();dirIt++)
{
if(dirIt->second>=minSupport)
{
set<int> large;
large.insert(dirIt->first);
largeItemset.push_back(large);
lm1.insert(pair<set<int>,int>(large,dirIt->second));
}
}
}
//输出大项集
void OutputLargeItemset(vector<set<int> > &largeItemset,unsigned int i)
{
cout<<"包含 "<<largeItemset.size()<<" 项的 "<<i<<"-项大项集:"<<endl;
vector<set<int> >::iterator largeItemsetIt;
int j=0;
for(largeItemsetIt=largeItemset.begin();largeItemsetIt!=largeItemset.end();largeItemsetIt++)
{
set<int> temp=*largeItemsetIt;
cout<<"{ ";
for(set<int>::iterator tempIt=temp.begin();tempIt!=temp.end();tempIt++)
{
cout<<(*tempIt)<<" ";
}
cout<<"}";
j++;
if(j%4==0)
{
cout<<endl;
}
}
cout<<endl<<endl;
}
//连接步骤,若it1和it2符合连接条件,则把它们连接为temp,返回true,否则返回false
bool Joint(set<int> &recordI,set<int> &recordJ,set<int> &temp)
{
if(recordI.size()!=recordJ.size())
{//俩集合大小不一样,立马返回!
return false;
}
set<int>::iterator it1=recordI.begin();
set<int>::iterator it2=recordJ.begin();
unsigned int size=recordI.size()-1;
for(int i=0;i<size;i++)
{
if(*it1!=*it2)
{
return false;
}
temp.insert(*it1);
it1++;
it2++;
}
if(*it1==*it2)
{
return false;
}
temp.insert(*it1);
temp.insert(*it2);
//cout<<"连接"<<*it1<<" "<<*it2<<endl;
return true;
}
//剪枝步骤,若temp的k-1项集有不在L[k-1]中,则剪掉,返回false,否则返回true
bool Prune(set<int> &temp,vector<set<int> > &largeTemp)
{
unsigned int size=temp.size();
//获取temp的全部k-1项子集,并判断每个子集是否在L[k-1]中
for(int i=0;i<size;i++)
{
set<int>::iterator tempIt=temp.begin();
set<int> tempMinusOne;//盛放k-1项子集
for(int j=0;j<size;j++)
{
if(j!=i)
{
tempMinusOne.insert(*tempIt);
}
*tempIt++;
}
//判断tempMinusOne是否在L[k-1]中
vector<set<int> >::iterator largeTempIt;
bool flag=false;//temp是否被剪掉的标识
for(largeTempIt=largeTemp.begin();largeTempIt!=largeTemp.end();largeTempIt++)
{//对大项集集合largeTemp中的大项集*largeTempIt逐个与tempMinusOne进行比对,看相不相同,相同就会保证flag=true,否则为false
flag=true;
set<int> large=*largeTempIt;
set<int>::iterator tempMinusOneIt=tempMinusOne.begin();
for(set<int>::iterator largeIt=large.begin();largeIt!=large.end();largeIt++)
{
if(*largeIt!=*tempMinusOneIt)
{
flag=false;
break;
}
tempMinusOneIt++;
}
if(flag==true)
{//存在了,不用再和其它大项集比较了,浪费时间
return true;
}
}
}
return false;
}
//利用L[k-1],通过连接和剪枝两个步骤,生成候选集集合candidate
void AprioriGen(vector<set<int> > &largeTemp,vector<set<int> > &candidate)
{
unsigned int largeTempSize=largeTemp.size();
unsigned int sizeTemp=largeTempSize-1;
vector<set<int> >::iterator largeTempIt=largeTemp.begin();
//L[k-1]中的大项集两两连接,求候选集集合
for(int i=0;i<sizeTemp;i++,largeTempIt++)
{//system("pause");cout<<largeTempSize<<" "<<i<<endl;
set<int> recordI=*largeTempIt;
for(int j=i+1;j<largeTempSize;j++)
{//cout<<j<<endl;
set<int> recordJ=*(largeTempIt+(j-i));
set<int> temp;
// cout<<"进行连接"<<endl;
if(Joint(recordI,recordJ,temp))
{//recordI和recordJ能连接成temp,则对temp进行剪枝
//cout<<"连接成功,进行剪枝"<<endl;
if(Prune(temp,largeTemp))
{//temp没有被剪掉,则把它加到候选集的集合中
if(!temp.empty())
// cout<<"temp不为空,没有被剪掉,成为到候选集"<<endl;
candidate.push_back(temp);
}
// else{cout<<"被剪掉了"<<endl;}
}
//else{cout<<"不符合连接条件"<<endl; }
}//system("pause");
}
}
//对比数据库中的每条交易,计算每个候选集的支持度,选出大于等于最小支持度的候选集来构成L[k]
void Subset(vector<set<int> > &database,vector<set<int> > &candidate,vector<set<int> > &largeK,unsigned int minSupport,map<set<int>,int> &lm)
{
vector<set<int> >::iterator databaseIt;
vector<set<int> >::iterator candidateIt;
for(candidateIt=candidate.begin();candidateIt!=candidate.end();candidateIt++)
{//对于每个候选集can
//bool cunzai=true;
set<int> can=*candidateIt;
//cout<<"cannnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn"<<endl;
unsigned int canCount=0;
for(databaseIt=database.begin();databaseIt!=database.end();databaseIt++)
{//对于数据库中每条交易,查看can是否在其中
set<int> data=*databaseIt;
if(can.size()>data.size())
{
continue;//候选集大小大于交易大小,肯定不在这个交易中
}
set<int>::iterator canIt;
for(canIt=can.begin();canIt!=can.end();canIt++)
{//对于can中每个项,看它是否在交易data中
if(data.find(*canIt)==data.end())
{
break;
}
}
if(canIt==can.end())
{//cout<<"在"<<endl;//system("pause");
canCount++;
//cout<<canCount<<endl;
}
}
if(canCount>=minSupport)
{//canCount只要大于等于最小支持度,我们就退出循环,不再对该候选集进行计数了,浪费时间
largeK.push_back(can);
lm.insert(pair<set<int>,int>(can,canCount));
}
}
}
int main(int argc,char *argv[])
{
char name[200];
string file="";
char *fileName="retail.dat";
int minSupport=5000;//最小支持度
/*
string ctl="";
cout<<"手动输入文件路径和最小支持度(Y/N)?";
cin>>ctl;
if(ctl=="Y"||ctl=="y")
{
cout<<"请依次输入文件路径和最小支持度,用空格隔开。(文件路径要用双斜杠):\n";
cin>>file>>minSupport;
strcpy(name,file.c_str());
fileName=name;
}
*/
vector<map<set<int>,int> > liss;
clock_t start=clock();
vector<set<int> > database;//数据库
ObtainDatabase(database,fileName);
vector<set<int> > large1;
map<set<int>,int> lm1;
CreateItemset(database,large1,minSupport,lm1);
liss.push_back(lm1);
int k=1;
vector<set<int> > largeTemp=large1;
while(!largeTemp.empty())
{
OutputLargeItemset(largeTemp,k);
k++;
vector<set<int> > candidate;
AprioriGen(largeTemp,candidate);
vector<set<int> > largeK;
map<set<int>,int> lm;
Subset(database,candidate,largeK,minSupport,lm);
largeTemp=largeK;
if(largeTemp.empty())
{
cout<<"L["<<k<<"]为空"<<endl;
}
else
{
liss.push_back(lm);
}
}
clock_t end=clock();
cout<<"Finish!共用时:"<<(end-start)<<"ms"<<endl;
system("pause");
}