机器学习入门算法及其java实现-Apriori(文本关联性)算法

算法原理:

基本概念介绍:

  • 支持度:
    对于事件AB的支持度 support=P(AB)
  • 置信度:
    置信度confidence=P(B|A)=P(AB)/P(A)
    3、强关联规则:
    如果存在一条关联规则,它的支持度和置信度都大于预先定义好的最小支持度与置信度,我们就称它为强关联规则。强关联规则就可以用来了解项之间的隐藏关系。所以关联分析的主要目的就是为了寻找强关联规则,而Apriori算法则主要用来帮助寻找强关联规则。
    4、频繁项
    在多个集合中,频繁出现的元素/项,就是频繁项。
    5、频繁项集
    有一系列集合,这些集合有些相同的元素,集合中同时出现频率高的元素形成一个子集,满足一定阈值条件,就是频繁项集。
    6、极大频繁项集:
    元素个数最多的频繁项集合,即其任何超集都是非频繁项集。
    7、相似性分析:
    研究的对象是集合之间的相似性关系。而频繁项集分析,研究的集合间重复性高的元素子集。
    1.2、Apriori算法描述:
    对于寻找强关联规则,首先需要找到频繁集。对于频繁集,我们首先遍历数据集D,遍历它的每一条记录T,得到T的所有子集,然后计算每一个子集的支持度,得到的结果与最小支持度比较。由于频繁集的子集也一定是频繁集,非频繁集的超集一定是非频繁集,所以我们可以使用逐层搜索的方法。
    第一步:第一轮候选集是数据集D中的项,而其它轮次候选集则是有前一轮次频繁集自连接得到的频繁集(频繁集由候选集剪枝得到);
    第二步:对于候选集进行剪枝。对于候选集的每一条记录T,如果它的支持度小于最小支持度就会被剪掉;
    终止条件:如果自连接得到的不再是频繁集或者自连接得到结果是它自身,那么就取最后一次得到的频繁集作为结果。
    而对于强规则,对于一个频繁集{1,2,3},我们可以得到它的子集:{1},{2},{3},{1,2},{1,3},{2,3}。
    我们能够得到规则:123213321123132231 ,根据这些规则,我么可以得到其置信度,与最小置信度相比较,我们就可以得到强规则。
package apriori;

import java.io.IOException;
import java.util.ArrayList;


public class Apriorimain {
    public static void main(String[] args) throws IOException {
        InputStringData ori=new InputStringData();
        String[][] shopping=ori.loadData("购物数据.txt");
        for(int i=0;i<shopping.length;i++){
            for(int j=0;j<shopping[i].length;j++){
                System.out.print(shopping[i][j]+" ");
            }
            System.out.println(" ");
        }
        System.out.println(" ");
//打印出源数据
        Apriori tt=new Apriori(shopping);
        tt.node(shopping, tt.support);
        ArrayList<String[][]> frequence=tt.frequences();
        ArrayList<double[]> frequencevalue=tt.frequencesvalue();
        int pp=frequence.size();
        double ppt=tt.confidence;
        for(int k=0;k<pp;k++){
            String[][] c=frequence.get(k);
            double[] d=frequencevalue.get(k);
            System.out.println("L"+(k+1)+"阶频繁项:");
            for(int i=0;i<c.length;i++){
                for(int j=0;j<c[i].length;j++){
                    System.out.print(c[i][j]+" ");
                    }
                System.out.print("支持度:"+d[i]);
                System.out.println(" ");
                }
            System.out.println(" ");

            //输出Lk阶频繁项
            }  

        String[][] kth=frequence.get(pp-1);
        for(int k=0;k<pp-1;k++){
            String[][] c=frequence.get(k);
            double[] d=frequencevalue.get(k);
            System.out.println("L"+(k+1)+"阶强规则:");
            for(int i=0;i<c.length;i++){
                for(int k1=0;k1<kth.length;k1++){
                    if(tt.exist(kth[k1],c[i])){
                        if(frequencevalue.get(pp-1)[k1]/frequencevalue.get(k)[i]>=ppt){
                        for(int j=0;j<c[0].length;j++){
                            System.out.print(c[i][j]+" ");
                        }
                        System.out.print("->");
                        for(int t1=0;t1<kth[0].length;t1++){
                            System.out.print(kth[k1][t1]+" ");
                        }
                        System.out.print("置信度:"+frequencevalue.get(pp-1)[k1]/
frequencevalue.get(k)[i]);
                        System.out.println(" ");
                        }
                }
            }
        }
            System.out.println(" ");
        }
    }
    //计算强规则并打印
}

package apriori;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;

public class InputStringData {
           int countRow=0,countCol=0,temp=0;
    public  String[][] loadData(String trainfile)throws IOException{
           ArrayList<String>features = new ArrayList<String>();
           File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
           Scanner input1 = new Scanner(file);
           while(input1.hasNext()){
               String line = input1.nextLine();
               Scanner input2 = new Scanner(line);
               countRow++;input2.close();
           }
           Scanner input11 = new Scanner(file);
           int[] length=new int[countRow];
           int i=0;
           while(input11.hasNext()){
               String line = input11.nextLine();
               Scanner input2 = new Scanner(line);
               temp=0;
               while(input2.hasNext()){
               features.add(input2.next());
               temp++;
               }
               if(countCol<temp){
                   countCol=temp;
               }
               length[i]=temp;
               i++;
               input2.close();
           }
           input11.close();
           String [][]x = new String[countRow][countCol];
           int index=0;
           for(int i1=0;i1<countRow;i1++){
               for(int j=0;j<countCol;j++){
                   if(length[i1]<=j){
                       x[i1][j]="null";
                   }
                   else{
                       x[i1][j]=features.get(index);
                       index++;
                   }
               }
           }
   return x;
}
}
//输入原始数据

package apriori;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

public class Apriori {
    String[][]support;
    String[][]frequence;
    double[]frequencevalue;
    double supportdegree=0.2; 
    double confidence=0.8;
    ArrayList<String[][]> temp11=new ArrayList<String[][]>();
    ArrayList<double[]> temp111=new ArrayList<double[]>();
    public Apriori(String[][]a){
        Set<String> ori=new HashSet<String>();
        for(int i=0;i<a.length;i++){
            for(int j=0;j<a[i].length;j++){
                if(!a[i][j].equals("null")){
                    ori.add(a[i][j]);
                }
            }
        }
        support=new String[ori.size()][1];
        double[]d=new double[ori.size()];
        Iterator<String> set=ori.iterator();
        for(int i=0;i<support.length;i++){
            String[] d1=new String[1];
            d1[0]=set.next();
            for(int j=0;j<support[0].length;j++){
                if(ratio(a,d1)>=supportdegree){
                    support[i][j]=d1[j];
                }
                d[i]=ratio(a,d1);
            }
        }
        temp11.add(support);
        temp111.add(d);
    }
    //提取源数据元素并存入support数据框
    public void node(String[][]a, String[][]b){
        //a为原始数据,b为Lk-1阶频繁项集
        if(b.length==1){
            return;
        }
        if(b[0][0].equals("null")){
            return;
        }
        ArrayList<String[]> temp= new ArrayList<String[]>();
        ArrayList<Double> temp1= new ArrayList<Double>();
        for (int i=0;i<b.length-1;i++){
            for(int j=i+1;j<b.length;j++){
                if(seemclose(b[i],b[j])){
                    temp.add(conbine(b[i],b[j]));
                }
            }
        }
        //找到所有k阶项集并存入temp
        int index1=0;
        while(index1<temp.size()){
            int index2=index1+1;
            while(index2<temp.size()){
                if(isSame(temp.get(index1),temp.get(index2))){
                temp.remove(index2);
                index2--;
                }
                index2++;
            }
            //删除temp中相同的项
        int index=0;
        while(index<temp.size()){
            String[]temp11=temp.get(index);
            double t=ratio(a,temp11);
            if(t<supportdegree){
                temp.remove(index);
                index--;
            }
            else{
                temp1.add(t);
            }
            index++;
        }
        //删除temp中非频繁项集

            index1++;
        }
        if(temp.size()>0){
        frequence=new String[temp.size()][b[0].length+1];
        frequencevalue=new double[temp.size()];
        for(int i=0;i<frequence.length;i++){
            frequence[i]=temp.get(i);
            frequencevalue[i]=temp1.get(i);
            }
        }
        else{
            frequence=new String[1][1];
            frequence[0][0]="null";
            return;
        }
        temp11.add(frequence);
        temp111.add(frequencevalue);
        //将k阶频繁项集输入temp11
        node(a, frequence);
        //将Lk-1阶频繁项集输入叶子节点再次进行计算
    }
    public boolean isSame(String[] a, String[] b) {
        int k=0;
        boolean cjh=false;
        for(int i=0;i<a.length;i++){
            if(b[i].equals(a[i])){
            k++;
            }
        }
        if(k==b.length){
            cjh=true;
        }
        return cjh;
    }
    //检测a,b两个字符是否相同
    public double ratio(String[][] a, String[] b) {
        double temp=0;
        for(int i=0;i<a.length;i++){
            if(exist(a[i],b)){
                temp++;
            }
        }
        temp=temp/(double)a.length;
        return temp;
    }
    //计算源数据a中b的支持度
    public boolean exist(String[] a, String[] b) {
        boolean temp=false;
        int k=0;
        for(int i=0;i<b.length;i++){
            for(int j=0;j<a.length;j++){
                if(b[i].equals(a[j])){
                    k++;
                    break;
                }
            }
        }
        if(k==b.length){
            temp=true;
        }
        return temp;
    }
    //检测b数组是否在a数组中同时存在
    public String[] conbine(String[] a, String[] b) {
        Set<String> doknown=new HashSet<String>();
        for(int i=0;i<a.length;i++){
            doknown.add(a[i]);
            doknown.add(b[i]);
        }
        String[] c=new String[doknown.size()];
        Iterator<String> set=doknown.iterator();
        for(int i=0;i<c.length;i++){
            c[i]=set.next();
        }
        return c;
    }
    //将Lk-1阶项合并为Lk阶项
    public boolean seemclose(String[] a, String[] b) {
        boolean wantYou=false;
        int k=0;
        for(int i=0;i<a.length;i++){
            for(int j=i;j<a.length;j++){
                if(b[j]==a[i]){
                    k++;
                    break;
                }
            }
        }
        if(k==b.length-1){
            wantYou=true;
        }
        return wantYou;
    }
    //检测Lk-1阶项是否满足合并条件
    public ArrayList<String[][]> frequences() {
        ArrayList<String[][]> frequences=new ArrayList<String[][]>();
        frequences=temp11;
        return frequences;
    }
    //提出所有的Lk阶项集
    public ArrayList<double[]> frequencesvalue() {
        ArrayList<double[]> frequencesvalue=new ArrayList<double[]>();
        frequencesvalue=temp111;
        return frequencesvalue;
    }
    //提出所有支持度值
}

输入

i1 i2 i4 null
i1 i2 i3 null
i2 i3 null null
i1 i3 null null
i1 i2 i3 i5
i2 i3 null null
i1 i2 i5 null
i2 i4 null null
i1 i2 i3 null
i1 i3 null null

结果:

L1阶频繁项:
i1 支持度:0.7
i2 支持度:0.8
i3 支持度:0.7
i4 支持度:0.2
i5 支持度:0.2

L2阶频繁项:
i1 i2 支持度:0.5
i1 i3 支持度:0.5
i1 i5 支持度:0.2
i2 i3 支持度:0.5
i2 i4 支持度:0.2
i2 i5 支持度:0.2

L3阶频繁项:
i1 i2 i3 支持度:0.3
i1 i2 i5 支持度:0.2

L1阶强规则:
i5 ->i1 i2 i5 置信度:1.0

L2阶强规则:
i1 i5 ->i1 i2 i5 置信度:1.0
i2 i5 ->i1 i2 i5 置信度:1.0

发布了7 篇原创文章 · 获赞 3 · 访问量 5945
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览