【180】Java用堆实现从列表中获取第k小(或大)的元素

24 篇文章 0 订阅

本文讨论的是在不改变用户输入的列表的前提下,按照用户输入的顺序,输出第k个元素。其中k是从0开始计算。

1. 用堆的方法

下面的代码GetKthByHeapUtils.java,用户可以用自定义的排序规则,获取排序中第k个元素。

设计思路:我按照从小到大的排序为例子做讲解。先创建新的列表,容量是 k + 1,取名heap。把列表中的前 k + 1 个元素放入列表heap中,构造最大堆(如果从大到小就是最小堆)。遍历列表中剩余的元素,每个元素和堆顶做对比。如果小于堆顶就和堆顶交换位置,并且调整堆结构;反之就不做操作,直接比对下一个元素。最后直接返回堆顶即可。堆顶就是第k小的元素。

GetKthByHeapUtils.java

package zhangchao.getk;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
 * 根据用户指定的排序规则,获取第k个元素。k从0开始计算。
 * @author zhangchao
 */
public class GetKthByHeapUtils {

    /**
     * 创建堆
     * @param list 要进行排序的列表
     * @param listSize 列表长度
     * @param comparator 比较用的函数钩子
     * @param <T> list中的元素类型
     */
    private static<T> void createHeap(List<T> list, int listSize, Comparator<T> comparator) {
        // 假设第0个元素已经是堆了,从第1个元素开始加入堆。
        for (int i = 1; i < listSize; i++) {
            int newIndex = i;
            while (newIndex > 0) {
//              int parentIndex = (newIndex - 1) / 2;
                int parentIndex = (newIndex - 1) >> 1;
                T parent = list.get(parentIndex);
                T newNode = list.get(newIndex);
                if (comparator.compare(newNode, parent) > 0) {
                    list.set(parentIndex, newNode);
                    list.set(newIndex, parent);
                    newIndex = parentIndex;
                } else {
                    // 小于等于父亲节点,没有上升的需要,不需要再查找上级节点了。
                    newIndex = -1;
                }
            }
        }
    }

    /**
     * 从列表中获取,从小到大排序,第k个元素。
     * 利用堆来保存前面 K + 1 个元素,并且是最大堆。后面的元素只要小于堆顶元素,就和堆顶元素交换位置,
     * 然后调整堆的结构。
     * @param list 列表
     * @param k 第k个元素,k从0开始计算。
     * @param comparator 比较的函数钩子。
     * @param <T> 类型。
     * @return 从小到大排序,第k个元素。
     */
    public static<T> T getKth(List<T> list, int k, Comparator<T> comparator) {
        if (null == list || list.isEmpty()) {
            throw new RuntimeException("List is empty!");
        }
        if (k < 0) {
            throw new RuntimeException("K must be greater than or equal to 0 !");
        }
        final int size = list.size();
        if (k >= size) {
            throw new RuntimeException("K must be less than the size of list !");
        }
        if (0 == k) {
            T min = list.get(0);
            for (int i = 0; i < size; i++) {
                T t = list.get(i);
                if (comparator.compare(t, min) < 0) {
                    min = t;
                }
            }
            return min;
        }
        if ((size - 1) == k) {
            T max = list.get(0);
            for (int i = 0; i < size; i++) {
                T t = list.get(i);
                if (comparator.compare(t, max) > 0) {
                    max = t;
                }
            }
            return max;
        }
        // 堆的长度
        int heapLength = k + 1;
        List<T> heap = new ArrayList<>(heapLength);
        for (int i = 0; i < heapLength; i++) {
            heap.add(list.get(i));
        }
        // 创建堆
        createHeap(heap, heapLength, comparator);

        // 从第k+1个元素开始,每个元素和堆顶比较。如果小于堆顶,就和堆顶交换位置,
        // 然后调整堆的结构。
        for (int i = heapLength; i < size; i++) {
            T current = list.get(i);
            if (comparator.compare(current, heap.get(0)) < 0) {
                heap.set(0, current);
                int currentIndex = 0;
                boolean whileFlag = true;
                while(whileFlag) {
                    int leftIndex = (currentIndex << 1) + 1;
                    int rightIndex = (currentIndex << 1) + 2;
                    if (rightIndex < heapLength) {
                        T left = heap.get(leftIndex);
                        T right = heap.get(rightIndex);
                        int maxIndex = rightIndex;
                        T max = right;
                        if (comparator.compare(left, right) > 0) {
                            maxIndex = leftIndex;
                            max = left;
                        }
                        if (comparator.compare(max, current) > 0) {
                            heap.set(currentIndex, max);
                            heap.set(maxIndex, current);
                            currentIndex = maxIndex;
                        } else {
                            whileFlag = false;
                        }
                    } else if (leftIndex < heapLength) {
                        T left = heap.get(leftIndex);
                        if (comparator.compare(left, current) > 0) {
                            heap.set(currentIndex, left);
                            heap.set(leftIndex, current);
                            currentIndex = leftIndex;
                        } else {
                            whileFlag = false;
                        }
                    } else {
                        whileFlag = false;
                    }
                }
            }
        }

        return heap.get(0);
    }
}

2. 其他方法

下面是用了另外两个方法来实现功能:

GetKthByListUtils

package zhangchao.getk;

import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;

/**
 *
 * @author zhangchao
 */
public class GetKthByListUtils {
    /**
     * 复制列表,然后整个列表排序,返回第k个元素。
     * @param originList 列表
     * @param k 第k个元素,k从0开始计算。
     * @param comparator 比较的函数钩子。
     * @param <T> 类型。
     * @return 从小到大排序,第k个元素。
     */
    public static<T> T getKth_sortAll(List<T> originList, int k, Comparator<T> comparator) {
        List<T> list = new ArrayList<>();
        for (T t : originList) {
            list.add(t);
        }
        list.sort(comparator);
        return list.get(k);
    }

    /**
     * 前面k+1个元素组成小列表smallList,排序。后面的元素和小列表最后一个元素比较。如果小于smallList最后一个元素,
     * 交换位置,重新对smallList排序。
     * @param originList 列表
     * @param k 第k个元素,k从0开始计算。
     * @param comparator 比较的函数钩子。
     * @param <T> 类型。
     * @return 从小到大排序,第k个元素。
     */
    public static<T> T getKth_smallList(final List<T> originList, final int k, Comparator<T> comparator) {
        if (null == originList || originList.isEmpty()) {
            throw new RuntimeException("List is empty!");
        }
        if (k < 0) {
            throw new RuntimeException("K must be greater than or equal to 0 !");
        }
        final int size = originList.size();
        if (k >= size) {
            throw new RuntimeException("K must be less than the size of list !");
        }
        if (0 == k) {
            T min = originList.get(0);
            for (int i = 0; i < size; i++) {
                T t = originList.get(i);
                if (comparator.compare(t, min) < 0) {
                    min = t;
                }
            }
            return min;
        }
        if ((size - 1) == k) {
            T max = originList.get(0);
            for (int i = 0; i < size; i++) {
                T t = originList.get(i);
                if (comparator.compare(t, max) > 0) {
                    max = t;
                }
            }
            return max;
        }

        int smallListSize = k + 1;
        List<T> smallList = new ArrayList<>(smallListSize);
        for (int i = 0; i < smallListSize; i++) {
            smallList.add(originList.get(i));
        }
        smallList.sort(comparator);
        for (int i = smallListSize; i < originList.size(); i++) {
            T t = originList.get(i);
            if (comparator.compare(t, smallList.get(k)) < 0) {
                smallList.set(k, t);
//                smallList.sort(comparator);
                for (int smallIndex = 0; smallIndex < k; smallIndex++) {
                    T smallT = smallList.get(smallIndex);
                    if (comparator.compare(smallT, t) > 0) {
                        smallList.remove(k);
                        smallList.add(smallIndex, t);
                        smallIndex = k; // 结束循环。
                    }
                }
            }
        }
        return smallList.get(k);
    }
}

3. 对比测试

下面是测试代码,统一用了长度为10000的列表做测试。

package zhangchao.getk;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;

public class GetKMain {



    public static void main(String[] args) {
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < 10000; i++) {
            list.add(i);
        }
        Collections.shuffle(list);
        Comparator<Integer> comparator = ((o1, o2) -> o1 - o2);
        long t1, t2;

        final int k = 9000;

        t1 = System.currentTimeMillis();
        Integer k1 = GetKthByHeapUtils.getKth(list, k, comparator);
        t2 = System.currentTimeMillis();
        System.out.println("heap       k1=" + k1 + " time=" + (t2 - t1));

        t1 = System.currentTimeMillis();
        Integer k2 = GetKthByListUtils.getKth_sortAll(list, k, comparator);
        t2 = System.currentTimeMillis();
        System.out.println("sort all   k2=" + k2 + " time=" + (t2 - t1));

        t1 = System.currentTimeMillis();
        Integer k3 = GetKthByListUtils.getKth_smallList(list, k, comparator);
        t2 = System.currentTimeMillis();
        System.out.println("small list k3=" + k3 + " time=" + (t2 - t1));

    }

}

下面是统计折线图。横轴是k的取值,纵轴是耗时(单位:毫秒)。蓝色、红色、绿色分别代码代码中 heap、sort all、small list 三种代码实现。

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值