通用版归并排序ForkJoin线程池版本

话不多说,直接上代码

package com.vickllny.leetcodelearn;

import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * 归并排序
 * @author vickllny
 * @date 2022-02-11 22:31:20
 */
@Slf4j
public class MergeSortTask<T extends Comparable<T>> extends RecursiveTask<T[]> {

    private static final ForkJoinPool FORK_JOIN_POOL = new ForkJoinPool();

    /**
     * 小数组的最大长度
     */
    private static final int MIN = 2;

    private final T[] data;

    public MergeSortTask(T[] data) {
        Objects.requireNonNull(data, "data cannot be null");
        this.data = data;
    }

    public T[] sort(){
        return FORK_JOIN_POOL.invoke(this);
    }

    private T[] mergeSort(){
        return mergeSort(this.data);
    }

    /**
     * 归并排序
     * PS:将数组分割成最大长度为2的数组
     * @param data
     * @param <T>
     * @return
     */
    public static <T extends Comparable<T>> T[] mergeSort(T[] data){
        if(data.length < MIN){
            return data;
        }
        if(data.length == MIN){
            //交换位置
            T v1 = data[0];
            T v2 = data[1];
            if(v1.compareTo(v2) > 0) {
                data[0] = v2;
                data[1] = v1;
            }
            return data;
        }
        BiArray<T> biArray = separateArray(data);
        final MergeSortTask<T> task1 = new MergeSortTask<>(biArray.getData1());
        task1.fork();
        final MergeSortTask<T> task2 = new MergeSortTask<>(biArray.getData2());
        task2.fork();
        return doMergeSort(task1.join(), task2.join());
    }

    @Override
    protected T[] compute() {
        return mergeSort();
    }

    /**
     * 真正执行排序
     * @param data1
     * @param data2
     * @param <T>
     * @return
     */
    private static <T extends Comparable<T>> T[] doMergeSort(T[] data1, T[] data2) {
        int length1 = data1.length;
        int length2 = data2.length;
        int max = length1 + length2;
        T[] data = (T[])Array.newInstance(data1[0].getClass(), max);
        for(int index = 0,i = 0,j = 0; index < max; index++){
            T value,value1 = null,value2 = null;
            if(i != length1){
                value1 = data1[i];
            }
            if(j != length2){
                value2 = data2[j];
            }
            if(value1 != null && value2 != null){
                final int compare = value1.compareTo(value2);
                if(compare < 0){
                    value = value1;
                    i++;
                }else if(compare > 0){
                    value = value2;
                    j++;
                }else {
                    value = value2;
                    i++;
                    j++;
                    data[index] = value;
                    data[index++] = value;
                    continue;
                }
            }else if(value1 != null){
                value = value1;
                i++;
            }else if(value2 != null){
                value = value2;
                j++;
            }else {
                break;
            }
            data[index] = value;
        }
        return data;
    }

    /**
     * 一个数组分割为2个数组
     * @param data
     * @param <T>
     * @return
     */
    private static <T extends Comparable<T>> BiArray<T> separateArray(T[] data){
        int length = data.length;
        int tempLength = length / 2;
        T[] data1 = (T[])Array.newInstance(data[0].getClass(), tempLength);
        System.arraycopy(data, 0, data1, 0, tempLength);
        T[] data2 = (T[])Array.newInstance(data[0].getClass(), length - tempLength);
        System.arraycopy(data, tempLength, data2, 0, length - tempLength);
        return new BiArray<T>(data1, data2);
    }

    /**
     * 二位数组临时对象,因为范型构建二位数组比较麻烦
     * @param <T>
     */
    private static class BiArray<T extends Comparable<T>> {
        private final T[] data1;
        private final T[] data2;
        public BiArray(T[] data1, T[] data2){
            this.data1 = data1;
            this.data2 = data2;
        }
        public T[] getData1() {
            return data1;
        }
        public T[] getData2() {
            return data2;
        }
    }


    /**
     * 测试
     * @param args
     */
    public static void main(String[] args) {
        int length = 1000000;
        Integer[] data = new Integer[length];
        Set<Integer> set = new HashSet<>();
        final Random random = new Random();
        for (int i = 0; i < length; i++) {
            while (true){
                //生成不重复的随机数
                Integer val = random.nextInt(length * 1000);
                if(set.contains(val)){
                    continue;
                }
                set.add(val);
                data[i] = val;
                break;
            }
        }
        final MergeSortTask<Integer> mergeSort = new MergeSortTask<>(data);
        final long start = System.currentTimeMillis();
        final Integer[] sort = mergeSort.mergeSort();
        final long end = System.currentTimeMillis();
        System.out.println((end - start) + "ms");
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值