集合 求交集 求差集 高效 算法 性能对比 与 分析 JAVA

本文介绍了在Java中求解两个集合交集和差集的不同方法,包括遍历、使用List的内置方法、哈希结构优化等,并进行了性能测试对比。分析了各方法的时间复杂度和稳定性,强调了哈希结构在提高性能方面的优势。
摘要由CSDN通过智能技术生成

问题:

两个集合,分别含有一定数量的元素,如何快速得到两个集合的交集和差集?

举例:

给定两个集合List<String> list1和List<String> list2,假定两个集合分别具有m和n个元素,要求取出两个集合中不同的元素,A比B多的元素和B比A多的元素。

说明:

1.以String作为集合中元素的类型,如果是自定义的数据结构,需要重写equals方法

2.输入参数:第一个集合list1,第二个集合list2

3.方法:实现求两个集合交集的方法 & 实现求第一个输入参数集合与第二个输入参数集合的正向差(list1-list2)【即第一个集合减去出现在第二个集合中的元素,相当于list1-commonList】

4.输出参数:交集/差集的集合结果

实现:

方法一:遍历两个集合

实现求 list1-list2:

    private static List<String> getDiffrent(List<String> list1, List<String> list2) {
        List<String> different = new ArrayList<String>();
        for (String str : list1) {
            if (!list2.contains(str)) {
                different.add(str);
            }
        }
        return different;
    }

实现求交集:

    private static List<String> getCommon(List<String> list1, List<String> list2) {
        List<String> common = new ArrayList<String>();
        for (String str : list1) {
            if (list2.contains(str)) {
                common.add(str);
            }
        }
        return common;
    }

分析:利用 List 自带的 contains 方法逐个判断另一个集合中的元素是否属于这个集合,总共要循环的次数是两个 List 的元素数量相乘的积,因此时间复杂度为 O(m*n) ~ O(n^2) ,空间复杂度为 O(m+n) ~ O(1) 。 

ArrayList 的 contains 方法使用遍历的方法进行判断,源码如下:


    /**
     * Returns <tt>true</tt> if this list contains the specified element.
     * More formally, returns <tt>true</tt> if and only if this list contains
     * at least one element <tt>e</tt> such that
     * <tt>(o==null&nbsp;?&nbsp;e==null&nbsp;:&nbsp;o.equals(e))</tt>.
     *
     * @param o element whose presence in this list is to be tested
     * @return <tt>true</tt> if this list contains the specified element
     */
    public boolean contains(Object o) {
        return indexOf(o) >= 0;
    }

    /**
     * Returns the index of the first occurrence of the specified element
     * in this list, or -1 if this list does not contain the element.
     * More formally, returns the lowest index <tt>i</tt> such that
     * <tt>(o==null&nbsp;?&nbsp;get(i)==null&nbsp;:&nbsp;o.equals(get(i)))</tt>,
     * or -1 if there is no such index.
     */
    public int indexOf(Object o) {
        if (o == null) {
            for (int i = 0; i < size; i++)
                if (elementData[i]==null)
                    return i;
        } else {
            for (int i = 0; i < size; i++)
                if (o.equals(elementData[i]))
                    return i;
        }
        return -1;
    }

其中 elementData 为 ArrayList 中存储元素的数组成员变量,使用 indexOf 方法确定重复元素在 ArrayList 中的位置 i ,通过判断 i 返回一个 Boolean 结果。

方法二:采用 List 提供的 retainAll() 和 removeAll() 方法

retainAll() 方法用于保留 arraylist 中在指定集合中也存在的那些元素,也就是删除指定集合中不存在的那些元素。retainAll() 方法的语法为:

arraylist.retainAll(Collection c);

注:arraylist 是 ArrayList 类的一个对象。

removeAll() 方法用于删除存在于指定集合中的动态数组元素。removeAll() 方法的语法为:

arraylist.removeAll(Collection c);

注:arraylist 是 ArrayList 类的一个对象。

因此,使用上述两个函数,我们可以得到求A-B(差集)和交集的方法:

实现求 list1-list2:

    private static List<String> getDiffrent2(List<String> list1, List<String> list2) {
        List<String> different = new ArrayList<String>(list1);
        different.removeAll(list2);
        return different;
    }

实现求交集: 

    private static List<String> getCommon2(List<String> list1, List<String> list2) {
        List<String> common = new ArrayList<String>(list1);
        common.retainAll(list2);
        return common;
    }

分析:同理,这两个方法函数内部也是使用循环进行比较和处理的,因此总共要循环的次数还是两个 List 的元素数量相乘的积,时间复杂度为 O(m*n) ~ O(n^2) ,空间复杂度为 O(m+n) ~ O(1) 。 

retainAll() 和 removeAll() 方法的实现如下:


    /**
     * Removes from this list all of its elements that are contained in the
     * specified collection.
     *
     * @param c collection containing elements to be removed from this list
     * @return {@code true} if this list changed as a result of the call
     * @throws ClassCastException if the class of an element of this list
     *         is incompatible with the specified collection
     * (<a href="Collection.html#optional-restrictions">optional</a>)
     * @throws NullPointerException if this list contains a null element and the
     *         specified collection does not permit null elements
     * (<a href="Collection.html#optional-restrictions">optional</a>),
     *         or if the specified collection is null
     * @see Collection#contains(Object)
     */
    public boolean removeAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(c, false);
    }

    /**
     * Retains only the elements in this list that are contained in the
     * specified collection.  In other words, removes from this list all
     * of its elements that are not contained in the specified collection.
     *
     * @param c collection containing elements to be retained in this list
     * @return {@code true} if this list changed as a result of the call
     * @throws ClassCastException if the class of an element of this list
     *         is incompatible with the specified collection
     * (<a href="Collection.html#optional-restrictions">optional</a>)
     * @throws NullPointerException if this list contains a null element and the
     *         specified collection does not permit null elements
     * (<a href="Collection.html#optional-restrictions">optional</a>),
     *         or if the specified collection is null
     * @see Collection#contains(Object)
     */
    public boolean retainAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(c, true);
    }

    private boolean batchRemove(Collection<?> c, boolean complement) {
        final Object[] elementData = this.elementData;
        int r = 0, w = 0;
        boolean modified = false;
        try {
            for (; r < size; r++)
                if (c.contains(elementData[r]) == complement)
                    elementData[w++] = elementData[r];
        } finally {
            // Preserve behavioral compatibility with AbstractCollection,
            // even if c.contains() throws.
            if (r != size) {
                System.arraycopy(elementData, r,
                                 elementData, w,
                                 size - r);
                w += size - r;
            }
            if (w != size) {
                // clear to let GC do its work
                for (int i = w; i < size; i++)
                    elementData[i] = null;
                modCount += size - w;
                size = w;
                modified = true;
            }
        }
        return modified;
    }

其中,以下方法用于判断输入参数是否为 NULL 。

Objects.requireNonNull(c);

batchRemove方法用于把输入参数集合C中出现/未出现的元素重新放在集合中。通过一个 boolean 类型的输入参数进行控制,通过比较控制是否在C中出现/未出现,主要流程在try中实现,使用for循环进行遍历。在finally中处理元素的数量,因为这里直接对类的实例(调用方法的对象)进行了修改,因此需要清除多余的数据。

方法三:使用带有 Hash 的数据结构辅助提高性能

上述两种方法最大的性能问题在于查找使用了遍历的方法,因此优化可以分为两个方向:

  • 第一,对于有序的数据,我们就可以使用快速查找的算法,比如设置指针 i 和 j ,分别交替比较一轮,即可得到差集、交集,相应的时间复杂度和空间复杂度是 O(m+n) ~ O(1) 。因此,对于未排序数据可以先对于list1和list2先进行排序,此时使用快速排序等高性能算法,然后执行上述操作。
  • 第二,对于无序数据先排序再比较仍然耗费很多的性能,因为排序不是我们的最终目的,只是一个中转站给我们歇歇脚。因此,对于无序数据在查找上我们可以使用 Hash 结构,通过hash值实现快速查找。

因此我们采用第二种方思路进行实现,以获得更高的性能。 我们可以用一个map存放lsit的所有元素,其中的key为lsit1的各个元素,value为该元素出现的次数,接着使用list2的所有元素逐一和map里的key进行比较,如果已经存在则value更新为0。最终,只要取出map里value为1的元素即可实现 list1-list2;取出map里value为0的元素即可实现求交集。只需循环m+n次,大大减少了循环的次数。

实现求 list1-list2:

    private static List<String> getDiffrent3(List<String> list1, List<String> list2) {
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size());
        List<String> different = new ArrayList<String>();
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
            }
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                different.add(entry.getKey());
            }
        }
        return different;
    }

实现求交集: 

    private static List<String> getCommon3(List<String> list1, List<String> list2) {
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size());
        List<String> common = new ArrayList<String>();
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
            }
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 0) {
                common.add(entry.getKey());
            }
        }
        return common;
    }

这种方式最适合的方式其实是在一个函数方法中判直接得到差集和交集。分别取得map结构中值为0和为1的key集合即可。例如:

    private static void getAll(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> different = new ArrayList<String>(); // list1-list2
        List<String> different2 = new ArrayList<String>(); // list2-list1
        List<String> common = new ArrayList<String>(); // 全集
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
                continue;
            }
            map.put(string, 2);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 0) {
                common.add(entry.getKey());
            } else if (entry.getValue() == 1) {
                different.add(entry.getKey());
            } else if (entry.getValue() == 2) {
                different2.add(entry.getKey());
            }
        }
        System.out.println("getAll total times: " + (System.nanoTime() - st));
    }

其中,common对应value为0的,表示全集,different对应value为1的,表示list1-list2,different2对于value为2的,表示list2-list1。 

方法四:方法三的进一步改进

显然,这种方法大大减少耗时,是方法1的1/4,是方法2的1/40,这个性能的提升时相当可观的,但是,这不是最佳的解决方法,观察方法3我们只是随机取了一个list作为首次添加的标准,这样一旦我们的list2比list1的size大,则我们第二次put时的if判断也会耗时(put方法非常耗时),因此当我们不需要知道是list1-list2还是list2-list1的差集时,或者只需要知道两个差集的合集时,我们可以改为:

    private static List<String> getDiffrent4(List<String> list1, List<String> list2) {
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> diff = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            Integer cc = map.get(string);
            if (cc != null) {
                map.put(string, ++cc);
                continue;
            }
            map.put(string, 1);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                diff.add(entry.getKey());
            }
        }
        return diff;
    }

此时,返回的diff就是二者的差集的合集。

    private static List<String> getCommon4(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> common = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            Integer cc = map.get(string);
            if (cc != null) {
                map.put(string, ++cc);
                continue;
            }
            map.put(string, 1);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 2) {
                common.add(entry.getKey());
            }
        }
        System.out.println("getCommon4 total times " + (System.nanoTime() - st));
        return common;
    }

此时,返回的common就是两个集合的交集。

这种方式减少了循环里的判断,对两个list的大小进行了判断,小的在最后添加,进一步提高了性能。

方法五:方法四的进一步改进

两个List不管有多少个重复,只要重复的元素在两个List都能找到,则不应该包含在返回值里面,所以在做第二次循环时,这样判断:如果当前元素在map中找不到,则肯定需要添加到返回值中,如果能找到则value++,遍历完之后diff里面已经包含了只在list2里而没在list2里的元素,剩下的工作就是找到list1里有list2里没有的元素,遍历map取value为1的即可。

优化点:原先,判断在map中找不到会插进去,最后统一取出所有value为定值的所有key,现在将这其中的一部分数据的存拿出来直接放到结果中,不需要取,最后合并返回即可。

    private static List<String> getDiffrent5(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        List<String> diff = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        Map<String, Integer> map = new HashMap<String, Integer>(maxList.size());
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            if (map.get(string) != null) {
                map.put(string, 2);
                continue;
            }
            diff.add(string);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                diff.add(entry.getKey());
            }
        }
        System.out.println("getDiffrent5 total times " + (System.nanoTime() - st));
        return diff;
    }

这种方式对求交集没有优化,仅对求双差集有性能提高。

性能测试对比

我们使用分别具有一万个数的两个list进行测试,运行时间结果为:

getDiffrent2 total times: 262562400
getDiffrent3 total times: 11714800
getDiffrent4 total time: 8856901
getDiffrent5 total times 4786700


getCommon total times: 253939399
getCommon2 total times: 257528099
getCommon3 total times: 8060700
getCommon4 total times: 7648300


getAll total times: 5642201

可以明显看到性能上的优化。

此外,观察结果,我们可以发现,getDiffrent和getDiffrent2生成的结果是按序的,而getDiffrent345均使用了Hash使得最后的结果是乱序的。同理,getCommon和getCommon2也是按顺序的,getCommon3和getCommon4是不按照顺序的。

测试代码:

import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.ArrayList;


public class Main {
    public static void main(String[] args) {
        List<String> list1 = new ArrayList<String>();
        List<String> list2 = new ArrayList<String>();
        for (int i = 0; i < 10000; i++) {
            list1.add("test" + i);
            list2.add("test" + i * 2);
        }
        List<String> res = getDiffrent(list1, list2);
        List<String> res2 = getDiffrent2(list1, list2);
        List<String> res3 = getDiffrent3(list1, list2);
        List<String> res4 = getDiffrent4(list1, list2);
        List<String> res5 = getDiffrent5(list1, list2);

        List<String> res6 = getCommon(list1, list2);
        List<String> res7 = getCommon2(list1, list2);
        List<String> res8 = getCommon3(list1, list2);
        List<String> res9 = getCommon4(list1, list2);

        getAll(list1, list2);
    }

    private static List<String> getDiffrent5(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        List<String> diff = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        Map<String, Integer> map = new HashMap<String, Integer>(maxList.size());
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            if (map.get(string) != null) {
                map.put(string, 2);
                continue;
            }
            diff.add(string);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                diff.add(entry.getKey());
            }
        }
        System.out.println("getDiffrent5 total times " + (System.nanoTime() - st));
        return diff;
    }

    private static List<String> getDiffrent4(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> diff = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            Integer cc = map.get(string);
            if (cc != null) {
                map.put(string, ++cc);
                continue;
            }
            map.put(string, 1);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                diff.add(entry.getKey());
            }
        }
        System.out.println("getDiffrent4 total time: " + (System.nanoTime() - st));
        return diff;
    }

    private static List<String> getCommon4(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> common = new ArrayList<String>();
        List<String> maxList = list1;
        List<String> minList = list2;
        if (list2.size() > list1.size()) {
            maxList = list2;
            minList = list1;
        }
        for (String string : maxList) {
            map.put(string, 1);
        }
        for (String string : minList) {
            Integer cc = map.get(string);
            if (cc != null) {
                map.put(string, ++cc);
                continue;
            }
            map.put(string, 1);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 2) {
                common.add(entry.getKey());
            }
        }
        System.out.println("getCommon4 total times: " + (System.nanoTime() - st));
        return common;
    }

    private static List<String> getDiffrent3(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size());
        List<String> different = new ArrayList<String>();
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
            }
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 1) {
                different.add(entry.getKey());
            }
        }
        System.out.println("getDiffrent3 total times: " + (System.nanoTime() - st));
        return different;
    }

    private static List<String> getCommon3(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size());
        List<String> common = new ArrayList<String>();
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
            }
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 0) {
                common.add(entry.getKey());
            }
        }
        System.out.println("getCommon3 total times: " + (System.nanoTime() - st));
        return common;
    }

    private static void getAll(List<String> list1, List<String> list2) {
        long st = System.nanoTime();
        Map<String, Integer> map = new HashMap<String, Integer>(list1.size() + list2.size());
        List<String> different = new ArrayList<String>(); // list1-list2
        List<String> different2 = new ArrayList<String>(); // list2-list1
        List<String> common = new ArrayList<String>(); // 全集
        for (String string : list1) {
            map.put(string, 1);
        }
        for (String string : list2) {
            if (map.get(string) != null) {
                map.put(string, 0);
                continue;
            }
            map.put(string, 2);
        }
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() == 0) {
                common.add(entry.getKey());
            } else if (entry.getValue() == 1) {
                different.add(entry.getKey());
            } else if (entry.getValue() == 2) {
                different2.add(entry.getKey());
            }
        }
        System.out.println("getAll total times: " + (System.nanoTime() - st));
    }

    private static List<String> getDiffrent2(List<String> list1, List<String> list2) {
        long st = System.nanoTime(); // 计时测试
        List<String> different = new ArrayList<String>(list1);
        different.removeAll(list2);
        System.out.println("getDiffrent2 total times: " + (System.nanoTime() - st)); // 输出运行时间
        return different;
    }

    private static List<String> getCommon2(List<String> list1, List<String> list2) {
        long st = System.nanoTime(); // 计时测试
        List<String> common = new ArrayList<String>(list1);
        common.retainAll(list2);
        System.out.println("getCommon2 total times: " + (System.nanoTime() - st)); // 输出运行时间
        return common;
    }

    private static List<String> getDiffrent(List<String> list1, List<String> list2) {
        long st = System.nanoTime(); // 计时测试
        List<String> different = new ArrayList<String>();
        for (String str : list1) {
            if (!list2.contains(str)) {
                different.add(str);
            }
        }
        System.out.println("getDiffrent total times: " + (System.nanoTime() - st)); // 输出运行时间
        return different;
    }

    private static List<String> getCommon(List<String> list1, List<String> list2) {
        long st = System.nanoTime(); // 计时测试
        List<String> common = new ArrayList<String>();
        for (String str : list1) {
            if (list2.contains(str)) {
                common.add(str);
            }
        }
        System.out.println("getCommon total times: " + (System.nanoTime() - st)); // 输出运行时间
        return common;
    }
}

分析:

  • 方法1-5性能逐步提高(主要是时间复杂度);
  • 方法1和方法2使用了逐一比较,因此结果也是按顺序的,稳定的;方法3和方法4和方法5使用了Hash的方式,因此结果是随机的,不稳定,在具体使用中需要注意场景(可以选择加上快排);
  • 方法2是最简单的,代码量最少;
  • 方法1使用contains,该法可以在同时求交集、差集和全集这三个集合的时候只需要判断一次,因此如果同时求可以提高一些性能,但方法2直接调用,灵活性低;
  • 同样,同时求交集和差集使用getAll()方法能够提高性能,比单独执行getDiffrent()和getCommon()的时间加起来要少,但封装需要考虑同时返回两个集合的情况。

 其他:

在求得了交集和差集之后,可以很方便地得到全集。全集、交集和差集:这三者使用相同的方法下,可以知二得三。全集=交集+差集(2);交集=全集-差集;差集=全集-交集。

因此,相应地,使用下一篇文章(集合 求全集 高效 算法 性能对比 与 分析 JAVA)中求全集的方法再加上求交集/差集中的一个,也可以再求得另外一个。但是这种方式会有较高的时间和空间成本。

参考:

https://www.cnblogs.com/czpblog/archive/2012/08/06/2625794.html
https://blog.csdn.net/lixianrich/article/details/103822214
https://blog.csdn.net/sinat_21843047/article/details/78783681

以上就是关于集合操作的总结与性能分析,如果各位有其他方法,欢迎讨论交流并在评论区留言,文章将及时更新。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值