Apriori algorithm — Java implementation

package com.sduept.bigdata.ml.apriori.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Predicate;

import com.sduept.bigdata.ml.apriori.Apriori;
import com.sduept.ice.utils.ConsoleTable;
import com.sduept.ice.utils.TransformDataUtils;

/**
 * Association analysis via the Apriori algorithm: mines frequent itemsets
 * from a transaction data file and derives association rules together with
 * their confidence and lift.
 *
 * NOTE(review): all results accumulate in static mutable collections
 * ({@link #finlFreq}, {@link #associationRules}), so this class is not
 * thread-safe and constructing a second instance mixes its results with
 * the first one's.
 *
 * @author xuqinwen
 */
public class AprioriModel {
    
	
	// NOTE(review): these three fields are never assigned — the constructor
	// works directly from its parameters and leaves them at their defaults.
	private double minSupport;
	private double minConfidence;
	private String path;
	
	// All discovered frequent itemsets: itemset -> absolute support count.
	public static HashMap<List<String>, Integer> finlFreq = new HashMap<List<String>, Integer>();
	// Each rule is one flat list laid out as:
	// [antecedent items..., consequent item, confidence, lift] (all strings).
	public static List<List<String>> associationRules = new ArrayList<>();


	
	/**
	 * Prints the frequent itemsets in {@link #finlFreq} as a two-column
	 * console table ("items", "freq").
	 */
	public static  void freqItemsetsShow() {
		
		ConsoleTable table = new ConsoleTable(2, true);
		table.appendRow();
		table.appendColum("items").appendColum("freq");
		
		for(Map.Entry<List<String>, Integer> entry:finlFreq.entrySet()) {
			StringBuilder builder = new StringBuilder();
			// Join the itemset with commas, omitting a trailing comma.
			 for (int i = 0; i < entry.getKey().size(); i++) {
				 
				 if(i==entry.getKey().size()-1) {
					 builder.append(entry.getKey().get(i));
				 }else {
				builder.append(entry.getKey().get(i)+",");  
				 }
			}
			 
			 table.appendRow();
			 table.appendColum(builder.toString()).appendColum(entry.getValue().toString()); 
			 
		}
		System.out.println(table.toString());
	}
	
	
	/**
	 * Prints the mined association rules as a four-column console table.
	 * Relies on the flat layout of each rule list: the last element is the
	 * lift, the one before it the confidence, the one before that the
	 * consequent item, and everything earlier the antecedent.
	 */
	public static void associationRulesShow() {
		
		ConsoleTable table = new ConsoleTable(4, true);
		table.appendRow();
		table.appendColum("antecedent").appendColum("consequent").appendColum("condfidence").appendColum("lift");
		for (int i=0;i<associationRules.size();i++) {
			StringBuilder builder = new StringBuilder();
			// Positions of the trailing bookkeeping entries in the flat list.
			int liftSize = associationRules.get(i).size()-1;
			int confidenceSize = associationRules.get(i).size()-2;
			int consequentSize = associationRules.get(i).size()-3;
			// Antecedent items are indices 0 .. size-4; iterated in reverse,
			// so a comma follows every item except the one at index 0
			// (which is emitted last).
			for (int j =associationRules.get(i).size()-4; j >=0; j--) {
				if(j==0) {
					builder.append(associationRules.get(i).get(j));
				}else {
					builder.append(associationRules.get(i).get(j)+",");
				}
				
			}
			 table.appendRow();
			 table.appendColum(builder.toString()).appendColum(associationRules.get(i).get(consequentSize)).appendColum(associationRules.get(i).get(confidenceSize)).appendColum(associationRules.get(i).get(liftSize)); 
		}
		
		System.out.println(table.toString());
	}
	
	/**
	 * Runs the full pipeline: loads the data set from {@code path}, mines
	 * frequent itemsets, then derives the association rules. Results land
	 * in the static {@link #finlFreq} and {@link #associationRules}.
	 *
	 * @param minSupport    minimum support as a fraction of the row count
	 * @param minConfidence minimum confidence for a rule to be kept
	 * @param path          path to the transaction data file
	 * @throws Exception if the data file cannot be read
	 */
	public AprioriModel(double minSupport, double minConfidence, String path) throws Exception {
		List<List<String>> dataFrame = getDataFrame(path);
		HashMap<List<String>, Integer> freqItemsets = freqItemsets(dataFrame, minSupport, path);
		associationRules(freqItemsets, dataFrame, minConfidence, minSupport, path);
		
	}

	
	// TODO(review): this method is slated for removal — it just delegates to
	// freqItemsets and returns the static result map; `keys` is unused.
	public static HashMap<List<String>, Integer> getFinlFreq(String path, List<List<String>> dataFrame,
			double minSupport, String keys) throws Exception {

		freqItemsets(dataFrame, minSupport, path);

		return finlFreq;

	}

	// Loads the transaction data set (one list of items per row) from disk.
	private List<List<String>> getDataFrame(String path) throws Exception {
		List<List<String>> dataFrame = TransformDataUtils.dataFrame(path);
		return dataFrame;
	}

	/**
	 * Mines the frequent itemsets. First counts every distinct single item
	 * and keeps those meeting the absolute support threshold, then grows
	 * larger itemsets via {@link #groupFreq} and {@link #iteratorFreq}.
	 * All qualifying itemsets accumulate in the static {@link #finlFreq}.
	 *
	 * @param dataFrame  the transaction rows
	 * @param minSupport minimum support as a fraction of the row count
	 * @param path       data file path, used only to obtain the row count
	 * @return the static {@link #finlFreq} map (itemset -> support count)
	 * @throws Exception if the row count cannot be read from {@code path}
	 */
	public static HashMap<List<String>, Integer> freqItemsets(List<List<String>> dataFrame, double minSupport,
			String path) throws Exception {
		// NOTE(review): freqMaps is never used in this method.
		HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>();
		// List<List<String>> dataFrame1 = TransformDataUtils.dataFrame(path);
		// HashMap<String, Integer> freqMap = new HashMap<String, Integer>();
		// would hold qualifying frequent items: key = item, value = count
		int rowNum = TransformDataUtils.getLine(path);
		List<String> everyData = new ArrayList<>();// items already counted, so each distinct item is counted once
		double minSupportUse = rowNum * minSupport;// absolute support threshold used in the checks below
		int count_j = 0;
		int count_i = 0;
		String data = null;
		// Unusual traversal: count_j is a persistent column cursor. The inner
		// loop reads one cell and breaks; count_j advances at the bottom and i
		// only increments once a row is exhausted — net effect is visiting
		// every cell (i, count_j) of dataFrame exactly once.
		for (int i = count_i; i < dataFrame.size();) {
			for (int j = count_j; j < dataFrame.get(i).size();) {
				data = dataFrame.get(i).get(j);
				break;
			}
			if (!everyData.contains(data)) {
				everyData.add(data);
				int freqNum = 0;
				// Count how many rows contain this item.
				for (int k = 0; k < dataFrame.size(); k++) {
					if (dataFrame.get(k).contains(data)) {
						freqNum++;
					}
				}
				List<String> dataList = new ArrayList<>();
				dataList.add(data);
				if (freqNum >= minSupportUse) {
					finlFreq.put(dataList, freqNum);
				}
			}
			count_j++;
			if (count_j == dataFrame.get(i).size()) {
				i++;
				count_j = 0;
			}
		}
		
		
		// Grow 2-itemsets from the frequent singletons, then iterate to
		// larger sizes while at least two itemsets survive the threshold.
		if (finlFreq.size() >= 2) {

			HashMap<List<String>, Integer> freqMap = groupFreq(finlFreq, dataFrame, minSupportUse);
			finlFreq.putAll(freqMap);
			if (freqMap.size() >= 2) {
				iteratorFreq(freqMap, dataFrame, minSupportUse);
			}
			// freqMaps.putAll(thirdMap);
		}
		return finlFreq;

	}

	
	/**
	 * Combines the current frequent singletons pairwise and re-scans the
	 * data set, keeping the pairs that meet the support threshold.
	 *
	 * NOTE(review): entry.getKey().get(0) assumes every key in freqMap is a
	 * singleton itemset — only valid for the first growth step.
	 *
	 * @param freqMap       first-stage frequent itemsets (singletons)
	 * @param dataFrame     the original transaction rows
	 * @param minSupportUse pre-computed absolute minimum support
	 * @return 2-itemsets meeting the threshold, with their support counts
	 * @throws Exception
	 */
	public static HashMap<List<String>, Integer> groupFreq(HashMap<List<String>, Integer> freqMap,
			List<List<String>> dataFrame, double minSupportUse) throws Exception {
		HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>();
		List<List<String>> nextFreq = new ArrayList<>(); // candidate itemsets for the next stage, built from the previous stage
		List<String> freq = new ArrayList<>();
		for (Entry<List<String>, Integer> entry : freqMap.entrySet()) {
			freq.add(entry.getKey().get(0));
		}
		
		// Same persistent-cursor traversal as in freqItemsets: enumerates all
		// unordered pairs (i, count_j) with i < count_j.
		int count_i = 0;
		int count_j = count_i + 1;
        for (int i = count_i; i < freq.size() - 1;) {
			for (int j = count_j; j < freq.size();) {
				List<String> groupFreq = new ArrayList<>();// one combined pair
				groupFreq.add(freq.get(i));
				groupFreq.add(freq.get(j));
				nextFreq.add(groupFreq);
				break;
			}
			count_j++;
			if (count_j == freq.size()) {
				i++;
				count_j = i + 1;
			}
		}
		// Count support for each candidate pair and keep the qualifying ones.
		for (int i = 0; i < nextFreq.size(); i++) {
			int freqNum = 0;// support counter
			List<String> everyData = new ArrayList<>();
			everyData = nextFreq.get(i);
			for (int j = 0; j < dataFrame.size(); j++) {
				if (dataFrame.get(j).containsAll(everyData)) {
					freqNum++;
				}
			}
			if (freqNum >= minSupportUse) {
				freqMaps.put(everyData, freqNum);
			}
		}
	
		return freqMaps;
	}

	/**
	 * Recursively grows larger frequent itemsets: unions every pair of the
	 * previous stage's itemsets (a Set removes duplicate items), re-counts
	 * support against the data set, adds qualifying itemsets to the static
	 * {@link #finlFreq}, and recurses while at least two survive.
	 *
	 * @param freqMap       the previous stage's frequent itemsets
	 * @param dataFrame     the original transaction rows
	 * @param minSupportUse pre-computed absolute minimum support
	 * @return
	 */
	public static void iteratorFreq(HashMap<List<String>, Integer> freqMap, List<List<String>> dataFrame,
			double minSupportUse) {

		List<List<String>> nextFreq = new ArrayList<>(); 
		for (Map.Entry<List<String>, Integer> entry : freqMap.entrySet()) {
			nextFreq.add(entry.getKey());
		}
		List<List<String>> thirdFreq = new ArrayList<>();// candidates built by unioning pairs of previous-stage itemsets
		// Persistent-cursor pair enumeration again (i < count_j).
		int count_i = 0;
		int count_j = count_i + 1;
		for (int i = count_i; i < nextFreq.size() - 1;) {
			Set<String> freqSet = new HashSet<>();// set removes duplicate items when unioning the two itemsets
			for (int j = count_j; j < nextFreq.size();) {

				for (int j2 = 0; j2 < nextFreq.get(i).size(); j2++) {
					freqSet.add(nextFreq.get(i).get(j2));
				}
				for (int j2 = 0; j2 < nextFreq.get(j).size(); j2++) {
					freqSet.add(nextFreq.get(j).get(j2));
				}
				List<String> freq = new ArrayList<>();
				freq.addAll(freqSet);
				thirdFreq.add(freq);// NOTE(review): unions of different pairs may coincide; a Set of lists would dedupe
				break;
			}
			count_j++;
			if (count_j == nextFreq.size()) {
				i++;
				count_j = i + 1;
			}
		}
		// Count support for each candidate and keep the qualifying ones.
		HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>();
		for (int i = 0; i < thirdFreq.size(); i++) {
			int freqNum = 0;// support counter
			List<String> everyData = new ArrayList<>();
			everyData = thirdFreq.get(i);
			for (int j = 0; j < dataFrame.size(); j++) {
				if (dataFrame.get(j).containsAll(everyData)) {
					freqNum++;
				}
			}
			if (freqNum >= minSupportUse) {
				finlFreq.put(everyData, freqNum);
				freqMaps.put(everyData, freqNum);
			}
		}
		// Recurse until fewer than two itemsets survive the threshold.
		if (freqMaps.size() >= 2) {
			iteratorFreq(freqMaps, dataFrame, minSupportUse);
		}
	}

	/**
	 * Derives the association rules (confidence and lift). For every
	 * frequent itemset used as antecedent, each frequent single item is
	 * tried as consequent; rules whose combined support and confidence
	 * meet the thresholds are appended to the static
	 * {@link #associationRules} as flat lists
	 * [antecedent..., consequent, confidence, lift].
	 *
	 * @param freqItemsets  the frequent itemsets
	 * @param dataFrame     the original transaction rows
	 * @param minConfidence minimum confidence
	 * @param minSupport    minimum support (fraction)
	 * @return the static {@link #associationRules} list
	 * @throws Exception
	 */
	public List<List<String>> associationRules(HashMap<List<String>, Integer> freqItemsets,
			List<List<String>> dataFrame, double minConfidence, double minSupport, String path) throws Exception {
		List<List<String>> allFreq = new ArrayList<>();
		List<String> single = new ArrayList<>();
		// Split the itemsets into all antecedent candidates and the single
		// items that will be tried as consequents.
		for (Map.Entry<List<String>, Integer> entry : freqItemsets.entrySet()) {
			allFreq.add(entry.getKey());
			if (entry.getKey().size() == 1) {
				single.add(entry.getKey().get(0));
			}
		}
		///List<List<String>> confidenceFreq = new ArrayList<>();
		for (int i = 0; i < allFreq.size(); i++) {

			// Working copy of the antecedent; consequent/confidence/lift get
			// appended to it and stripped again between candidate consequents.
			List<String> group = new ArrayList<>();
			for (int j = 0; j < allFreq.get(i).size(); j++) {
				group.add(allFreq.get(i).get(j));
			}
			double antecedentNum = 0; 
			// int consequentNum = 0; 
			int sum = TransformDataUtils.getLine(path);
			double minSupportUse = (sum*minSupport);
			// int expectConfidenceNum = 0;
			// Support count of the antecedent alone.
			for (int j = 0; j < dataFrame.size(); j++) {
				if (dataFrame.get(j).containsAll(group)) {
					antecedentNum++;
				}
			}
			for (int j = 0; j < single.size(); j++) {// try each single item as consequent; strip it afterwards either way
				double consequentNum = 0; // transactions containing antecedent + consequent; must meet min support
				// int sum = TransformDataUtils .getLine(path);// number of data rows
				double expectConfidenceNum = 0;// occurrences of the consequent on its own
				// Undo whatever the previous iteration appended: either just
				// the consequent, or consequent + confidence + lift.
				if (group.size() == allFreq.get(i).size() + 1) {
					group.remove(allFreq.get(i).size());// strip the previously tried consequent
				} else if (group.size() > allFreq.get(i).size() + 1) {
					group.remove(allFreq.get(i).size() + 2);// strip the lift
					group.remove(allFreq.get(i).size() + 1);// strip the confidence
					group.remove(allFreq.get(i).size());// strip the consequent
				}
				if (!group.contains(single.get(j))) {// group.get(0).equals(single.get(j)) &&

					group.add(single.get(j));
					for (int j1 = 0; j1 < dataFrame.size(); j1++) {
						if (dataFrame.get(j1).containsAll(group)) {
							consequentNum++;
						}
						if (dataFrame.get(j1).contains(single.get(j))) {
							expectConfidenceNum++;
						}

					}
					if (consequentNum >= minSupportUse) {

						double confidence = consequentNum / antecedentNum;// confidence = support(A∪B) / support(A)
						if (confidence >= minConfidence) {
							double expectConfidence = expectConfidenceNum / sum;// expected confidence = support(B) / N
							double lift = (confidence / expectConfidence);// lift = confidence / expected confidence
							group.add(String.valueOf(confidence));// append the qualifying confidence
							group.add(String.valueOf(lift)); // append the lift
							
							// Snapshot: group keeps being mutated, so store a copy.
							List<String> groupUse = new ArrayList<>();
							groupUse.addAll(group);
							associationRules.add(groupUse);
						}
					}
				}
			} // end loop over single items
		}
		return associationRules;
	}

	// NOTE(review): unimplemented stub — always returns null.
	public String[][] lift(String[][] associationRules, String where) {
		// TODO Auto-generated method stub
		return null;
	}


	/**
	 * Overload of {@link #freqItemsets(List, double, String)} that mines
	 * only the frequent single items into a local map (the static
	 * {@link #finlFreq} is untouched and no larger itemsets are grown).
	 * The {@code test} parameter only disambiguates the overload.
	 */
	public HashMap<List<String>, Integer> freqItemsets(List<List<String>> dataFrame, double minSupport, String path,String test)
			throws Exception {
		// List<List<String>> dataFrame1 = TransformDataUtils.dataFrame(path);
		HashMap<List<String>, Integer> freqMap = new HashMap<List<String>, Integer>();
		int rowNum = TransformDataUtils.getLine(path);
		List<String> everyData = new ArrayList<>();// items already counted, so each distinct item is counted once
		double minSupportUse = rowNum * minSupport;
		int count_j = 0;
		int count_i = 0;
		String data = null;
		// Same persistent-cursor cell traversal as the static overload.
		for (int i = count_i; i < dataFrame.size();) {
			for (int j = count_j; j < dataFrame.get(i).size();) {
				data = dataFrame.get(i).get(j);
				break;
			}
			if (!everyData.contains(data)) {
				everyData.add(data);
				int freqNum = 0;
				for (int k = 0; k < dataFrame.size(); k++) {
					if (dataFrame.get(k).contains(data)) {
						freqNum++;
					}
				}

				List<String> dataList = new ArrayList<>();
				dataList.add(data);
				if (freqNum >= minSupportUse) {
					freqMap.put(dataList, freqNum);
				}
			}
			count_j++;
			if (count_j == dataFrame.get(i).size()) {
				i++;
				count_j = 0;
			}
		}

	
		// NOTE(review): secondFreq is never used.
		List<List<String>> secondFreq = new ArrayList<>();

		return freqMap;
	}

}

/**
 * Loads the data set from a text file: one transaction per line, items
 * separated by single spaces.
 *
 * @param path path of the data file
 * @return one inner list of items per input line
 * @throws Exception if the file does not exist or cannot be read
 */
	public static List<List<String>> dataFrame(String path) throws Exception {
		File file = new File(path);
		if(!file.exists()) {
			throw new Exception("文件不存在!");
		}
		List<List<String>> datas = new ArrayList<>();
		// try-with-resources: the original closed the reader only on the
		// success path and leaked it when readLine() threw.
		try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
			String line;
			while ((line = reader.readLine()) != null) {
				// Same tokenization as before: split on single spaces.
				List<String> rowData = new ArrayList<>();
				for (String item : line.split(" ")) {
					rowData.add(item);
				}
				datas.add(rowData);
			}
		}
		return datas;
	}
	
	/**
	 * 获得文件行数
	 * @param path
	 * @return
	 * @throws Exception
	 */
	public static Integer getLine(String path) throws Exception {
		File file = new File(path);
		if(!file.exists()) {
			throw new Exception("文件不存在!");
		}
		FileReader read = new FileReader(file);
		LineNumberReader reader = new LineNumberReader(read);
		reader.skip(Long.MAX_VALUE);
		
		int line = reader.getLineNumber()+1;
		return line;
		
	}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值