apriori算法java实现,关联规则分析

这是一门课程的作业。Apriori 算法的原理可自行在网上查阅,本文只提供实现代码和测试数据。

代码:

package com.outsider.apriori;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

/**
 * Apriori implementation: mines every itemset whose absolute support reaches
 * {@link #SUPPORT} and then derives single-consequent association rules whose
 * confidence reaches {@link #CONFIDENCE}.
 *
 * <p>An itemset is represented as a string of its items joined with
 * {@link #ITEM_SPLIT}. Itemsets produced by this class always list their items
 * in natural (sorted) order, which is what allows subsets to be looked up by
 * their joined string.
 *
 * @author outsider
 */
public class Apriori {
    private static final int SUPPORT = 4; // minimum support (absolute number of transactions)
    private static final double CONFIDENCE = 0.6; // minimum confidence
    private static final String ITEM_SPLIT = ","; // item separator
    private static final String FILE_ADDRESS = "./data/weather.nominal.txt"; // input data file
    private static final String CON = "->"; // separates antecedent and consequent in a rule

    /**
     * Reads {@link #FILE_ADDRESS}, mines the frequent itemsets level by level
     * and prints the itemset counts per size, the frequent itemsets and the
     * association rules.
     *
     * @param args unused
     * @throws IOException if the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        Apriori apriori = new Apriori();
        Map<String, Integer> frequentSetMap = new HashMap<>();
        // 1. Read every line of the file as one transaction.
        ArrayList<String[]> listFromFile = apriori.readFile(FILE_ADDRESS);
        // 2. Start from the frequent 1-itemsets and grow one level per iteration.
        Map<String, Integer> tempFrequentSetMap = apriori.findFrequentOneSets(listFromFile);
        while (!tempFrequentSetMap.isEmpty()) {
            frequentSetMap.putAll(tempFrequentSetMap);
            tempFrequentSetMap = apriori.getCandidateSetMap(tempFrequentSetMap);
            tempFrequentSetMap = apriori.getFrequentSetMap(listFromFile, tempFrequentSetMap);
        }
        // Count how many frequent itemsets exist for each size: L(1), L(2), ..., L(k).
        Map<Integer, Integer> setNums = new HashMap<>();
        for (String set : frequentSetMap.keySet()) {
            setNums.merge(set.split(ITEM_SPLIT).length, 1, Integer::sum);
        }
        System.out.println("各项集个数:");
        setNums.forEach((k, v) -> System.out.println("L(" + k + "):" + v));
        // Derive the association rules.
        Map<String, Double> associationRules = apriori.getAssociationRules(frequentSetMap);

        System.out.println("频繁项集:");
        List<String> list = new ArrayList<String>(frequentSetMap.keySet());
        Collections.sort(list);
        for (String string : list) {
            System.out.println(string + " : " + frequentSetMap.get(string));
        }

        System.out.println("关联规则: ");
        List<String> associationRulesList = new ArrayList<String>(associationRules.keySet());
        // Sort so that the output order does not depend on HashMap iteration order.
        Collections.sort(associationRulesList);
        for (String string : associationRulesList) {
            System.out.println(string + " : " + associationRules.get(string));
        }
    }

    /**
     * Reads the data file; every non-blank line becomes one transaction whose
     * items are the {@link #ITEM_SPLIT}-separated, whitespace-trimmed fields.
     *
     * @param fileAdd path of the file to read
     * @return the transactions in file order; never {@code null}
     * @throws IOException if the file does not exist or cannot be read
     */
    private ArrayList<String[]> readFile(String fileAdd) throws IOException {
        ArrayList<String[]> transactions = new ArrayList<>();
        // try-with-resources closes the reader exactly once, also on failure;
        // errors propagate to the caller instead of being turned into null.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileAdd)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // a blank line is not a transaction
                }
                String[] cols = line.split(ITEM_SPLIT);
                for (int i = 0; i < cols.length; i++) {
                    cols[i] = cols[i].trim();
                }
                transactions.add(cols);
            }
        }
        return transactions;
    }

    /**
     * Finds the frequent 1-itemsets.
     *
     * @param dataList the transactions
     * @return map from item to its occurrence count, containing only items
     *         whose count is at least {@link #SUPPORT}
     */
    private Map<String, Integer> findFrequentOneSets(
            ArrayList<String[]> dataList) {
        Map<String, Integer> frequentOneSets = new HashMap<>();
        for (String[] sample : dataList) {
            for (String column : sample) {
                frequentOneSets.merge(column, 1, Integer::sum);
            }
        }
        // Keep only the items that reach the minimum support.
        frequentOneSets.values().removeIf(count -> count < SUPPORT);
        return frequentOneSets;
    }

    /**
     * Generates the candidate (k+1)-itemsets from the frequent k-itemsets
     * (join step) and discards every candidate that has a k-subset which is
     * not frequent (prune step).
     *
     * @param inputMap the frequent k-itemsets; keys must list their items in
     *                 sorted order (as produced by this class)
     * @return map from each surviving candidate to an initial count of 0;
     *         empty when {@code inputMap} is {@code null} or empty
     */
    private Map<String, Integer> getCandidateSetMap(
            Map<String, Integer> inputMap) {
        Map<String, Integer> result = new HashMap<>();
        if (inputMap == null || inputMap.isEmpty()) {
            return result;
        }
        Set<String> keys = inputMap.keySet();
        // Size of the itemsets to generate.
        int kSetLen = keys.iterator().next().split(ITEM_SPLIT).length + 1;
        // Split every key once instead of once per pair.
        List<String[]> itemSets = new ArrayList<>(keys.size());
        for (String key : keys) {
            itemSets.add(key.split(ITEM_SPLIT));
        }
        // Join step: the union is symmetric, so unordered pairs are enough.
        for (int first = 0; first < itemSets.size(); first++) {
            for (int second = first + 1; second < itemSets.size(); second++) {
                Set<String> union = new HashSet<>();
                Collections.addAll(union, itemSets.get(first));
                Collections.addAll(union, itemSets.get(second));
                if (union.size() != kSetLen) {
                    continue;
                }
                List<String> sorted = new ArrayList<>(union);
                Collections.sort(sorted);
                // Prune step: every k-subset of a frequent itemset must be frequent.
                if (hasInfrequentSubset(sorted, keys)) {
                    continue;
                }
                result.put(String.join(ITEM_SPLIT, sorted), 0);
            }
        }
        return result;
    }

    /**
     * Tells whether removing any single item from the candidate yields an
     * itemset that is missing from the frequent itemsets.
     *
     * @param sortedCandidate the candidate's items in sorted order
     * @param frequentKeys    joined strings of the frequent itemsets one size smaller
     * @return {@code true} if at least one such subset is not frequent
     */
    private static boolean hasInfrequentSubset(List<String> sortedCandidate,
            Set<String> frequentKeys) {
        for (int skip = 0; skip < sortedCandidate.size(); skip++) {
            List<String> subset = new ArrayList<>(sortedCandidate);
            subset.remove(skip);
            if (!frequentKeys.contains(String.join(ITEM_SPLIT, subset))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Counts the support of every candidate and removes the candidates that do
     * not reach {@link #SUPPORT}.
     *
     * @param inputList the transactions
     * @param inputMap  the candidates mapped to their initial count; updated in place
     * @return {@code inputMap}, reduced to the frequent itemsets with their support
     */
    private Map<String, Integer> getFrequentSetMap(ArrayList<String[]> inputList,
            Map<String, Integer> inputMap) {
        // Split every candidate once rather than once per transaction.
        Map<String, String[]> candidateItems = new HashMap<>();
        for (String candidate : inputMap.keySet()) {
            candidateItems.put(candidate, candidate.split(ITEM_SPLIT));
        }

        Set<String> transaction = new HashSet<>();
        for (String[] row : inputList) {
            transaction.clear();
            Collections.addAll(transaction, row);
            for (Entry<String, String[]> candidate : candidateItems.entrySet()) {
                boolean containsAll = true;
                for (String item : candidate.getValue()) {
                    if (!transaction.contains(item)) {
                        containsAll = false;
                        break;
                    }
                }
                if (containsAll) {
                    inputMap.merge(candidate.getKey(), 1, Integer::sum);
                }
            }
        }

        inputMap.values().removeIf(count -> count < SUPPORT);
        return inputMap;
    }

    /**
     * Derives association rules from the frequent itemsets. For every itemset
     * with at least two items, each item in turn is used as the consequent and
     * the remaining items as the antecedent; confidence is
     * support(itemset) / support(antecedent).
     *
     * @param frequentSetMap frequent itemsets mapped to their support
     * @return map from {@code "antecedent->consequent"} to its confidence,
     *         containing only rules whose confidence is at least {@link #CONFIDENCE}
     */
    private Map<String, Double> getAssociationRules(
            Map<String, Integer> frequentSetMap) {
        Map<String, Double> rules = new HashMap<>();
        for (Entry<String, Integer> frequentSet : frequentSetMap.entrySet()) {
            String[] items = frequentSet.getKey().split(ITEM_SPLIT);
            if (items.length < 2) {
                continue; // a rule needs both an antecedent and a consequent
            }
            for (int i = 0; i < items.length; i++) {
                String res = items[i];
                List<String> antecedent = new ArrayList<>(items.length - 1);
                for (int j = 0; j < items.length; j++) {
                    if (j != i) {
                        antecedent.add(items[j]);
                    }
                }
                // Build the rule antecedent->res.
                String condition = String.join(ITEM_SPLIT, antecedent);
                Integer conditionSupport = frequentSetMap.get(condition);
                if (conditionSupport == null || conditionSupport == 0) {
                    // The antecedent's support is unknown, so no confidence can be computed.
                    continue;
                }
                double confidence = frequentSet.getValue() * 1.0 / conditionSupport;
                if (confidence >= CONFIDENCE) {
                    rules.put(condition + CON + res, confidence);
                }
            }
        }
        return rules;
    }
}

数据(保存为 ./data/weather.nominal.txt,每行一条记录,字段以英文逗号分隔):

sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

 

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值