数据挖掘 Apriori算法的Java代码实现

简单说明

学院开了一门课《数据挖掘与机器学习》,要求我们计算机1、2两个班的全部同学选修这门课,包括课程实验。教材采用王振武、徐慧编著的《数据挖掘算法原理与实现》。教材里面提供的代码是C++代码,而由于本人更习惯使用Java语言编程,为了深入理解算法原理和过程,完成实验任务,于是用Java语言实现了Apriori关联规则挖掘算法。

Apriori算法

Apriori算法的基本思想是通过对数据库的多次扫描来计算项集的支持度,发现所有的频繁项集从而生成关联规则。

其实就是从一堆数据里面找出出现次数最多的数据组合,找出来的组合就是强关联的。

产生频繁项集的过程包括连接和剪枝两步。

连接步:
假设有两个有序3-项集L1 = (A, B, C),L2 = (A, B, D)。则L1和L2可连接产生4-项集C1 = (A, B, C, D)。
剪枝步:
频繁k-项集的任何自己必须是频繁项集,根据这个性质去除连接步产生的不满足支持度的k-项集。

代码如下:

//Item.java

import java.util.ArrayList;

/**
 * 项集
 */
@SuppressWarnings("hiding")
public class Item<String> extends ArrayList<String> {

    private static final long serialVersionUID = 1L;

    /**
     * 判断本项集与next项集是否可连接
     * 
     * @param next
     * @return
     */
    public boolean linkable(Item<String> next) {
        if (this.size() != next.size())
            return false;
        for (int i = 0; i < this.size() - 1; i++) {
            if (!get(i).equals(next.get(i)))
                return false;
        }
        return true;
    }

    /**
     * 对项集去重
     */
    public void unique() {
        String s = get(0);
        for (int i = 1; i < size(); i++) {
            String t = get(i);
            while (t.equals(s)) {
                remove(t);
                if (i < size())
                    t = get(i);
                else {
                    break;
                }
            }
            s = t;
        }
    }

}
//Apriori.java

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;

/**
 * 算法实体
 */
public class Apriori {

    private HashMap<String, Integer> oneElementSet; // 一项集

    private ArrayList<Item<String>> sourceItems; // 原始数据

    private ArrayList<HashMap<Item<String>, Integer>> rankFrequentSets; // 各级频繁项集

    private int minValue; // 最小阈值

    Apriori(int size, int minValue) {
        oneElementSet = new HashMap<>();
        sourceItems = new Item<>();
        rankFrequentSets = new Item<>();
        this.minValue = minValue;
    }

    /**
     * 添加项集
     * 
     * @param item
     */
    public void addItem(Item<String> item) {
        // 对项集排序后添加
        item.sort(new Comparator<String>() {
            @Override
            public int compare(String arg0, String arg1) {
                return arg0.compareTo(arg1);
            }
        });
        sourceItems.add(item);
    }

    public ArrayList<HashMap<Item<String>, Integer>> getRankFrequentSets() {
        return rankFrequentSets;
    }

    /**
     * 找出一项集
     * 
     * @return
     */
    public HashMap<String, Integer> findOneElementItems() {
        for (Item<String> list : sourceItems) {
            for (String s : list) {
                if (!oneElementSet.containsKey(s)) {
                    oneElementSet.put(s, 1);
                } else {
                    oneElementSet.put(s, oneElementSet.get(s) + 1);
                }
            }
        }
        return oneElementSet;
    }

    /**
     * 产生频繁一项集
     * 
     * @return
     */
    public HashMap<Item<String>, Integer> obtainFrequentOneElementSet() {
        HashMap<Item<String>, Integer> map = new HashMap<>();
        for (String key : oneElementSet.keySet()) {
            int value = oneElementSet.get(key);
            if (value >= minValue) {
                Item<String> item = new Item<>();
                item.add(key);
                map.put(item, value);
            }
        }
        rankFrequentSets.add(0, map);
        return map;
    }

    /**
     * 产生频繁K项集 剪枝步
     * 
     * @param k
     * @return
     */
    public HashMap<Item<String>, Integer> obtainFrequentSet(int k) {
        Item<Item<String>> items = link(k);
        HashMap<Item<String>, Integer> freSet = new HashMap<>();
        for (Item<String> item : items) {
            int count = 0;
            for (Item<String> source : sourceItems) {
                boolean flag = true;
                for (String s : item) {
                    if (!source.contains(s)) {
                        flag = false;
                        break;
                    }
                }
                if (flag) {
                    count++;
                }
            }
            if (count >= minValue) {
                freSet.put(item, count);
            }
        }
        if (freSet.size() <= 0)
            return null;
        rankFrequentSets.add(k - 1, freSet);
        return freSet;
    }

    /**
     * 连接产生K项集
     * 
     * @param k
     * @return
     */
    public Item<Item<String>> link(int k) {
        Item<Item<String>> items = new Item<>();
        HashMap<Item<String>, Integer> map = rankFrequentSets.get(k - 2);
        Set<Item<String>> keys = map.keySet();
        Iterator<Item<String>> iterator = keys.iterator();
        if (k == 2) {
            for (int i = 0; i < keys.size(); i++) {
                Item<String> item = iterator.next();
                Iterator<Item<String>> iterator2 = keys.iterator();
                for (int j = 0; j < i + 1; j++) {
                    iterator2.next();
                }
                for (int j = i + 1; j < keys.size(); j++) {
                    Item<String> item2 = iterator2.next();
                    Item<String> instance = new Item<>();
                    instance.add(item.get(0));
                    instance.add(item2.get(0));
                    items.add(instance);
                }
            }
            return items;
        } else {
            for (int i = 0; i < keys.size() - 1; i++) {
                Item<String> item = iterator.next();
                Iterator<Item<String>> iterator2 = keys.iterator();
                for (int j = 0; j < i + 1; j++) {
                    iterator2.next();
                }
                for (int j = i + 1; j < keys.size(); j++) {
                    Item<String> item2 = iterator2.next();
                    if (item.linkable(item2)) {
                        Item<String> instance = new Item<>();
                        for (int n = 0; n < k - 1; n++) {
                            instance.add(item.get(n));
                        }
                        instance.add(item2.get(k - 2));
                        items.add(instance);
                    }
                }
            }
        }
        return items;
    }

}
//Main.java

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        int size;
        int minValue;
        Scanner scanner = new Scanner(System.in);
        System.out.print("事务数:");
        size = scanner.nextInt();
        System.out.print("最小阈值:");
        minValue = scanner.nextInt();
        Apriori apriori = new Apriori(size, minValue);
        scanner.nextLine();
        for (int i = 0; i < size; i++) {
            Item<String> item = new Item<>();
            System.out.print("输入第" + (i + 1) + "项:");
            String line = scanner.nextLine();
            Scanner scanner2 = new Scanner(line);
            while (scanner2.hasNext()) {
                item.add(scanner2.next());
            }
            scanner2.close();
            item.unique();//对输入的项集去重
            apriori.addItem(item);
        }
        scanner.close();

        HashMap<String, Integer> oneElementSet = apriori.findOneElementItems();
        Iterator<String> iterator = oneElementSet.keySet().iterator();
        while (iterator.hasNext()) {
            String key = iterator.next();
            System.out.println(key + ":" + oneElementSet.get(key));
        }

        apriori.obtainFrequentOneElementSet();
        int k = 2;
        while (apriori.obtainFrequentSet(k++) != null)
            ;
        ArrayList<HashMap<Item<String>, Integer>> rankSets = apriori.getRankFrequentSets();
        Item<String> item = null;
        HashMap<Item<String>, Integer> map = null;
        for (int i = 0; i < k - 2; i++) {
            map = rankSets.get(i);
            System.out.println("第 " + (i + 1) + " 级频繁项集:");
            Iterator<Item<String>> iterator2 = map.keySet().iterator();
            while (iterator2.hasNext()) {
                item = iterator2.next();
                System.out.print("{ ");
                for (String s : item) {
                    System.out.print(s + " ");
                }
                System.out.print("}\t");
                System.out.println(map.get(item));
            }
        }
        System.out.println("最终频繁项集:");
        Iterator<Item<String>> iterator2 = map.keySet().iterator();
        while (iterator2.hasNext()) {
            item = iterator2.next();
            System.out.print("{ ");
            for (String s : item) {
                System.out.print(s + " ");
            }
            System.out.print("}\t");
            System.out.println(map.get(item));
        }
    }

}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值