package com.sduept.bigdata.ml.apriori.impl;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Predicate;

import com.sduept.bigdata.ml.apriori.Apriori;
import com.sduept.ice.utils.ConsoleTable;
import com.sduept.ice.utils.TransformDataUtils;
/**
 * Association analysis (Apriori): mines frequent itemsets and association
 * rules from a whitespace-delimited transaction file.
 *
 * <p>Results are accumulated into the public static fields {@link #finlFreq}
 * and {@link #associationRules}.
 *
 * @author xuqinwen
 */
public class AprioriModel {

    // NOTE(review): these three fields are never assigned anywhere; the
    // constructor uses its parameters directly. Kept as-is for compatibility.
    private double minSupport;
    private double minConfidence;
    private String path;

    // All discovered frequent itemsets: itemset -> absolute support count.
    // NOTE(review): static and never cleared, so constructing a second model
    // in the same JVM accumulates results from previous runs.
    public static HashMap<List<String>, Integer> finlFreq = new HashMap<List<String>, Integer>();

    // Each rule is encoded as one flat string list, laid out positionally as
    // [antecedent items..., consequent, confidence, lift].
    public static List<List<String>> associationRules = new ArrayList<>();

    /**
     * Prints the mined frequent itemsets as a two-column console table:
     * comma-joined items, support count.
     */
    public static void freqItemsetsShow() {
        ConsoleTable table = new ConsoleTable(2, true);
        table.appendRow();
        table.appendColum("items").appendColum("freq");
        for(Map.Entry<List<String>, Integer> entry:finlFreq.entrySet()) {
            StringBuilder builder = new StringBuilder();
            // join the itemset with commas, no trailing comma on the last item
            for (int i = 0; i < entry.getKey().size(); i++) {
                if(i==entry.getKey().size()-1) {
                    builder.append(entry.getKey().get(i));
                }else {
                    builder.append(entry.getKey().get(i)+",");
                }
            }
            table.appendRow();
            table.appendColum(builder.toString()).appendColum(entry.getValue().toString());
        }
        System.out.println(table.toString());
    }

    /**
     * Prints the mined association rules as a four-column console table.
     * Rules are stored as [antecedent..., consequent, confidence, lift], so
     * the last three entries are read off by position.
     *
     * <p>NOTE(review): the header literal "condfidence" is misspelled
     * ("confidence"); it is runtime output, left unchanged here.
     */
    public static void associationRulesShow() {
        ConsoleTable table = new ConsoleTable(4, true);
        table.appendRow();
        table.appendColum("antecedent").appendColum("consequent").appendColum("condfidence").appendColum("lift");
        for (int i=0;i<associationRules.size();i++) {
            StringBuilder builder = new StringBuilder();
            int liftSize = associationRules.get(i).size()-1;       // index of lift (last)
            int confidenceSize = associationRules.get(i).size()-2; // index of confidence
            int consequentSize = associationRules.get(i).size()-3; // index of consequent
            // antecedent items are emitted in reverse index order
            for (int j =associationRules.get(i).size()-4; j >=0; j--) {
                if(j==0) {
                    builder.append(associationRules.get(i).get(j));
                }else {
                    builder.append(associationRules.get(i).get(j)+",");
                }
            }
            table.appendRow();
            table.appendColum(builder.toString()).appendColum(associationRules.get(i).get(consequentSize)).appendColum(associationRules.get(i).get(confidenceSize)).appendColum(associationRules.get(i).get(liftSize));
        }
        System.out.println(table.toString());
    }

    /**
     * Builds the model: loads the data set, mines the frequent itemsets, then
     * derives the association rules. Results land in the static fields
     * {@link #finlFreq} and {@link #associationRules}.
     *
     * @param minSupport    minimum support as a fraction of the row count
     * @param minConfidence minimum confidence for a rule to be kept
     * @param path          path of the transaction file
     * @throws Exception if the file cannot be read
     */
    public AprioriModel(double minSupport, double minConfidence, String path) throws Exception {
        List<List<String>> dataFrame = getDataFrame(path);
        HashMap<List<String>, Integer> freqItemsets = freqItemsets(dataFrame, minSupport, path);
        associationRules(freqItemsets, dataFrame, minConfidence, minSupport, path);
    }

    // Scheduled for removal (original note: "后续废弃掉这个方法" = drop this method
    // later). Delegates to freqItemsets(...) and returns the shared static map;
    // the 'keys' parameter is unused.
    public static HashMap<List<String>, Integer> getFinlFreq(String path, List<List<String>> dataFrame,
            double minSupport, String keys) throws Exception {
        freqItemsets(dataFrame, minSupport, path);
        return finlFreq;
    }

    // Loads the transaction file into row lists via TransformDataUtils.
    private List<List<String>> getDataFrame(String path) throws Exception {
        List<List<String>> dataFrame = TransformDataUtils.dataFrame(path);
        return dataFrame;
    }

    /**
     * Mines all frequent itemsets into the shared static {@link #finlFreq}.
     * Stage one counts frequent single items; larger itemsets come from
     * {@link #groupFreq} (pairs) and {@link #iteratorFreq} (recursion).
     *
     * @param dataFrame  transaction rows
     * @param minSupport minimum support as a fraction of the row count
     * @param path       file path, used only to obtain the row count
     * @return the shared static map of itemset -> support count
     * @throws Exception propagated from the row-count lookup
     */
    public static HashMap<List<String>, Integer> freqItemsets(List<List<String>> dataFrame, double minSupport,
            String path) throws Exception {
        HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>(); // NOTE(review): dead store, never used
        int rowNum = TransformDataUtils.getLine(path);
        List<String> everyData = new ArrayList<>();// items already counted, so each distinct item is counted only once
        double minSupportUse = rowNum * minSupport;// absolute support threshold used in the checks below
        int count_j = 0;
        int count_i = 0;
        String data = null;
        // Visits every (row, column) cell exactly once: the inner loop only
        // reads one cell per pass, and count_i/count_j do the index bookkeeping.
        for (int i = count_i; i < dataFrame.size();) {
            for (int j = count_j; j < dataFrame.get(i).size();) {
                data = dataFrame.get(i).get(j);
                break;
            }
            if (!everyData.contains(data)) {
                everyData.add(data);
                int freqNum = 0;
                // count the rows containing this item
                for (int k = 0; k < dataFrame.size(); k++) {
                    if (dataFrame.get(k).contains(data)) {
                        freqNum++;
                    }
                }
                List<String> dataList = new ArrayList<>();
                dataList.add(data);
                if (freqNum >= minSupportUse) {
                    finlFreq.put(dataList, freqNum);
                }
            }
            count_j++;
            if (count_j == dataFrame.get(i).size()) {
                i++;
                count_j = 0;
            }
        }
        // grow beyond singletons only when at least two survive
        if (finlFreq.size() >= 2) {
            HashMap<List<String>, Integer> freqMap = groupFreq(finlFreq, dataFrame, minSupportUse);
            finlFreq.putAll(freqMap);
            if (freqMap.size() >= 2) {
                iteratorFreq(freqMap, dataFrame, minSupportUse);
            }
        }
        return finlFreq;
    }

    /**
     * Combines the stage-one frequent single items pairwise and keeps the
     * pairs that meet the support threshold.
     *
     * @param freqMap       stage-one frequent itemsets (singletons)
     * @param dataFrame     original data set
     * @param minSupportUse absolute (already multiplied) minimum support
     * @return frequent 2-itemsets with their support counts
     * @throws Exception declared for API symmetry; not thrown here
     */
    public static HashMap<List<String>, Integer> groupFreq(HashMap<List<String>, Integer> freqMap,
            List<List<String>> dataFrame, double minSupportUse) throws Exception {
        HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>();
        List<List<String>> nextFreq = new ArrayList<>(); // candidate itemsets for the next stage, built from the previous stage
        List<String> freq = new ArrayList<>();
        for (Entry<List<String>, Integer> entry : freqMap.entrySet()) {
            freq.add(entry.getKey().get(0));
        }
        int count_i = 0;
        int count_j = count_i + 1;
        // Enumerates every unordered pair (i, j), i < j, via manual indices.
        for (int i = count_i; i < freq.size() - 1;) {
            for (int j = count_j; j < freq.size();) {
                List<String> groupFreq = new ArrayList<>();// one combined candidate pair
                groupFreq.add(freq.get(i));
                groupFreq.add(freq.get(j));
                nextFreq.add(groupFreq);
                break;
            }
            count_j++;
            if (count_j == freq.size()) {
                i++;
                count_j = i + 1;
            }
        }
        for (int i = 0; i < nextFreq.size(); i++) {
            int freqNum = 0;// support counter
            List<String> everyData = new ArrayList<>();
            everyData = nextFreq.get(i);
            for (int j = 0; j < dataFrame.size(); j++) {
                if (dataFrame.get(j).containsAll(everyData)) {
                    freqNum++;
                }
            }
            if (freqNum >= minSupportUse) {
                freqMaps.put(everyData, freqNum);
            }
        }
        return freqMaps;
    }

    /**
     * Recursively grows frequent itemsets: merges the previous stage's
     * itemsets pairwise (deduplicating items through a set), keeps those
     * meeting the support threshold, records them into {@link #finlFreq},
     * and recurses while at least two survivors remain.
     *
     * @param freqMap       previous stage's frequent itemsets
     * @param dataFrame     original data set
     * @param minSupportUse absolute (already multiplied) minimum support
     */
    public static void iteratorFreq(HashMap<List<String>, Integer> freqMap, List<List<String>> dataFrame,
            double minSupportUse) {
        List<List<String>> nextFreq = new ArrayList<>();
        for (Map.Entry<List<String>, Integer> entry : freqMap.entrySet()) {
            nextFreq.add(entry.getKey());
        }
        List<List<String>> thirdFreq = new ArrayList<>();// itemsets formed by merging this stage's pairs
        int count_i = 0;
        int count_j = count_i + 1;
        for (int i = count_i; i < nextFreq.size() - 1;) {
            Set<String> freqSet = new HashSet<>();// set removes duplicate items when merging two itemsets
            for (int j = count_j; j < nextFreq.size();) {
                for (int j2 = 0; j2 < nextFreq.get(i).size(); j2++) {
                    freqSet.add(nextFreq.get(i).get(j2));
                }
                for (int j2 = 0; j2 < nextFreq.get(j).size(); j2++) {
                    freqSet.add(nextFreq.get(j).get(j2));
                }
                List<String> freq = new ArrayList<>();
                freq.addAll(freqSet);
                thirdFreq.add(freq);// original note: merged results may still collide; a set should keep being used
                break;
            }
            count_j++;
            if (count_j == nextFreq.size()) {
                i++;
                count_j = i + 1;
            }
        }
        HashMap<List<String>, Integer> freqMaps = new HashMap<List<String>, Integer>();
        for (int i = 0; i < thirdFreq.size(); i++) {
            int freqNum = 0;// support counter
            List<String> everyData = new ArrayList<>();
            everyData = thirdFreq.get(i);
            for (int j = 0; j < dataFrame.size(); j++) {
                if (dataFrame.get(j).containsAll(everyData)) {
                    freqNum++;
                }
            }
            if (freqNum >= minSupportUse) {
                finlFreq.put(everyData, freqNum);
                freqMaps.put(everyData, freqNum);
            }
        }
        // recurse while growth is still possible
        if (freqMaps.size() >= 2) {
            iteratorFreq(freqMaps, dataFrame, minSupportUse);
        }
    }

    /**
     * Derives association rules (confidence and lift) from the frequent
     * itemsets and appends them to the static {@link #associationRules} list.
     * Each stored rule is a flat string list:
     * [antecedent items..., consequent, confidence, lift].
     *
     * @param freqItemsets  frequent itemsets to derive rules from
     * @param dataFrame     original data set
     * @param minConfidence minimum confidence for a rule to be kept
     * @param minSupport    minimum support as a fraction of the row count
     * @param path          file path, used only to obtain the row count
     * @return the shared static rules list
     * @throws Exception propagated from the row-count lookup
     */
    public List<List<String>> associationRules(HashMap<List<String>, Integer> freqItemsets,
            List<List<String>> dataFrame, double minConfidence, double minSupport, String path) throws Exception {
        List<List<String>> allFreq = new ArrayList<>();
        List<String> single = new ArrayList<>();
        // split out the singleton itemsets: each is a candidate consequent
        for (Map.Entry<List<String>, Integer> entry : freqItemsets.entrySet()) {
            allFreq.add(entry.getKey());
            if (entry.getKey().size() == 1) {
                single.add(entry.getKey().get(0));
            }
        }
        for (int i = 0; i < allFreq.size(); i++) {
            // 'group' starts as a copy of the antecedent itemset and is mutated
            // in place below (consequent/confidence/lift appended and stripped)
            List<String> group = new ArrayList<>();
            for (int j = 0; j < allFreq.get(i).size(); j++) {
                group.add(allFreq.get(i).get(j));
            }
            double antecedentNum = 0;
            int sum = TransformDataUtils.getLine(path); // total number of records
            double minSupportUse = (sum*minSupport);
            for (int j = 0; j < dataFrame.size(); j++) {
                if (dataFrame.get(j).containsAll(group)) {
                    antecedentNum++;
                }
            }
            for (int j = 0; j < single.size(); j++) {// try every single frequent item as the consequent
                double consequentNum = 0; // rows containing antecedent + consequent; must meet min support
                double expectConfidenceNum = 0;// rows containing just the single consequent item
                // Restore 'group' to the bare antecedent before trying the next
                // consequent, stripping whatever the previous round appended.
                if (group.size() == allFreq.get(i).size() + 1) {
                    group.remove(allFreq.get(i).size());// drop the consequent appended last round
                } else if (group.size() > allFreq.get(i).size() + 1) {
                    group.remove(allFreq.get(i).size() + 2);// drop lift (original comment said "support"; this slot holds lift)
                    group.remove(allFreq.get(i).size() + 1);// drop confidence
                    group.remove(allFreq.get(i).size());// drop the consequent
                }
                if (!group.contains(single.get(j))) {
                    group.add(single.get(j));
                    for (int j1 = 0; j1 < dataFrame.size(); j1++) {
                        if (dataFrame.get(j1).containsAll(group)) {
                            consequentNum++;
                        }
                        if (dataFrame.get(j1).contains(single.get(j))) {
                            expectConfidenceNum++;
                        }
                    }
                    if (consequentNum >= minSupportUse) {
                        double confidence = consequentNum / antecedentNum;// rule confidence
                        if (confidence >= minConfidence) {
                            double expectConfidence = expectConfidenceNum / sum;// expected confidence P(consequent)
                            double lift = (confidence / expectConfidence);// lift = confidence / P(consequent)
                            group.add(String.valueOf(confidence));// append the qualifying confidence
                            group.add(String.valueOf(lift)); // append the lift
                            List<String> groupUse = new ArrayList<>();
                            groupUse.addAll(group);
                            associationRules.add(groupUse);
                        }
                    }
                }
            } // end of loop over 'single'
        }
        return associationRules;
    }

    // Unimplemented stub; always returns null.
    public String[][] lift(String[][] associationRules, String where) {
        // TODO Auto-generated method stub
        return null;
    }

    /**
     * Instance variant of {@link #freqItemsets(List, double, String)} that
     * mines only the frequent single items into a local map (it does not touch
     * the static state and does not recurse to larger itemsets).
     * The extra {@code test} parameter is unused.
     *
     * @param dataFrame  transaction rows
     * @param minSupport minimum support as a fraction of the row count
     * @param path       file path, used only to obtain the row count
     * @param test       unused
     * @return frequent single-item itemsets -> support counts
     * @throws Exception propagated from the row-count lookup
     */
    public HashMap<List<String>, Integer> freqItemsets(List<List<String>> dataFrame, double minSupport, String path,String test)
            throws Exception {
        HashMap<List<String>, Integer> freqMap = new HashMap<List<String>, Integer>();
        int rowNum = TransformDataUtils.getLine(path);
        List<String> everyData = new ArrayList<>();// items already counted, so each distinct item is counted only once
        double minSupportUse = rowNum * minSupport;
        int count_j = 0;
        int count_i = 0;
        String data = null;
        // same manual cell walk as the static variant
        for (int i = count_i; i < dataFrame.size();) {
            for (int j = count_j; j < dataFrame.get(i).size();) {
                data = dataFrame.get(i).get(j);
                break;
            }
            if (!everyData.contains(data)) {
                everyData.add(data);
                int freqNum = 0;
                for (int k = 0; k < dataFrame.size(); k++) {
                    if (dataFrame.get(k).contains(data)) {
                        freqNum++;
                    }
                }
                List<String> dataList = new ArrayList<>();
                dataList.add(data);
                if (freqNum >= minSupportUse) {
                    freqMap.put(dataList, freqNum);
                }
            }
            count_j++;
            if (count_j == dataFrame.get(i).size()) {
                i++;
                count_j = 0;
            }
        }
        List<List<String>> secondFreq = new ArrayList<>(); // NOTE(review): dead store, never used
        return freqMap;
    }
}
/**
 * Reads a transaction file into a list of rows; each line is split on a
 * single space into its items.
 *
 * <p>NOTE(review): this method (and {@link #getLine}) sits OUTSIDE the class
 * closing brace above — it appears to belong to TransformDataUtils and should
 * be moved there or the file will not compile.
 *
 * @param path path of the data file
 * @return one List&lt;String&gt; per line of the file
 * @throws Exception if the file does not exist or cannot be read
 */
public static List<List<String>> dataFrame(String path) throws Exception {
    File file = new File(path);
    if (!file.exists()) {
        throw new Exception("文件不存在!");
    }
    List<List<String>> datas = new ArrayList<>();
    // try-with-resources: the original leaked the reader when readLine threw
    try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
        String line;
        while ((line = reader.readLine()) != null) {
            List<String> rowData = new ArrayList<>();
            // split(" ") yields [""] for an empty line, preserving the
            // original behavior of emitting a one-empty-string row
            for (String cell : line.split(" ")) {
                rowData.add(cell);
            }
            datas.add(rowData);
        }
    }
    return datas;
}
/**
 * Counts the number of data lines in a file, consistent with
 * {@code dataFrame(path).size()}.
 *
 * <p>Fixes in this revision: the reader is now closed (the original leaked
 * it); counting is done by reading lines instead of
 * {@code LineNumberReader.skip(Long.MAX_VALUE)} — skip() is not guaranteed to
 * consume the whole stream in one call, and the old {@code getLineNumber()+1}
 * over-counted by one for files ending in a newline.
 *
 * @param path path of the data file
 * @return number of lines, as produced by successive readLine() calls
 * @throws Exception if the file does not exist or cannot be read
 */
public static Integer getLine(String path) throws Exception {
    File file = new File(path);
    if (!file.exists()) {
        throw new Exception("文件不存在!");
    }
    int line = 0;
    try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
        while (reader.readLine() != null) {
            line++;
        }
    }
    return line;
}