话不多说,直接上代码
package com.vickllny.leetcodelearn;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Array;
import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
/**
* Parallel merge sort (归并排序) implemented as a Fork/Join {@code RecursiveTask}.
* @author vickllny
* @date 2022-02-11 22:31:20
*/
@Slf4j
public class MergeSortTask<T extends Comparable<T>> extends RecursiveTask<T[]> {

    /** Dedicated pool used by {@link #sort()}. */
    private static final ForkJoinPool FORK_JOIN_POOL = new ForkJoinPool();

    /**
     * Maximum length of a "small" array, i.e. one that {@code mergeSort} sorts directly instead of
     * splitting further. The base case in {@code mergeSort} is written for exactly two elements, so
     * changing this value requires changing that code as well.
     */
    private static final int MIN = 2;

    /** The array this task sorts. */
    private final T[] data;

    /**
     * Creates a task that sorts {@code data}.
     *
     * @param data the array to sort; must not be null and must not contain null elements
     * @throws NullPointerException if {@code data} is null
     */
    public MergeSortTask(T[] data) {
        Objects.requireNonNull(data, "data cannot be null");
        this.data = data;
    }

    /**
     * Sorts this task's array on the class's dedicated {@link ForkJoinPool}.
     *
     * @return the sorted elements; arrays of length 2 or less are sorted in place and returned,
     *     longer arrays are left untouched and a new sorted array is returned
     */
    public T[] sort() {
        return FORK_JOIN_POOL.invoke(this);
    }

    /** Sorts this task's array; used by {@link #compute()}. */
    private T[] mergeSort() {
        return mergeSort(this.data);
    }

    /**
     * Sorts {@code data} with a stable, parallel merge sort.
     *
     * <p>The array is split in halves until each piece has at most {@code MIN} (2) elements; the
     * sorted halves are then merged. Forked halves run in the caller's {@link ForkJoinPool}, or in
     * the common pool when called from outside one.
     *
     * @param data the array to sort; must not contain null elements
     * @param <T> the element type
     * @return the sorted elements. Arrays of length 2 or less are sorted in place and the same
     *     array is returned; longer arrays are not modified and a new array (with the same runtime
     *     component type as {@code data}) is returned
     */
    public static <T extends Comparable<T>> T[] mergeSort(T[] data) {
        if (data.length < MIN) {
            return data;
        }
        if (data.length == MIN) {
            // Two elements: swap only when strictly out of order, so equal elements keep their order.
            T first = data[0];
            T second = data[1];
            if (first.compareTo(second) > 0) {
                data[0] = second;
                data[1] = first;
            }
            return data;
        }
        BiArray<T> halves = separateArray(data);
        MergeSortTask<T> left = new MergeSortTask<>(halves.getData1());
        MergeSortTask<T> right = new MergeSortTask<>(halves.getData2());
        // Fork only one half and sort the other in the current thread instead of forking both and
        // blocking on two joins; this halves the number of queued tasks.
        right.fork();
        T[] sortedLeft = left.invoke();
        T[] sortedRight = right.join();
        return doMergeSort(sortedLeft, sortedRight);
    }

    @Override
    protected T[] compute() {
        return mergeSort();
    }

    /**
     * Merges two sorted arrays into a new sorted array.
     *
     * @param data1 the sorted left run
     * @param data2 the sorted right run
     * @param <T> the element type
     * @return a new array holding every element of both runs in order; on ties the element from
     *     {@code data1} comes first, which keeps the overall sort stable
     */
    private static <T extends Comparable<T>> T[] doMergeSort(T[] data1, T[] data2) {
        int length1 = data1.length;
        int length2 = data2.length;
        // Arrays.copyOf keeps data1's runtime component type, so any subtype of T can be stored
        // (allocating by the first element's class would throw ArrayStoreException for mixed
        // subtypes). Every slot is overwritten below.
        T[] merged = Arrays.copyOf(data1, length1 + length2);
        int i = 0;
        int j = 0;
        int index = 0;
        while (i < length1 && j < length2) {
            // "<= 0" takes from the left run on ties: stable, and each duplicate is copied once.
            if (data1[i].compareTo(data2[j]) <= 0) {
                merged[index++] = data1[i++];
            } else {
                merged[index++] = data2[j++];
            }
        }
        // At most one of these copies anything: the tail of whichever run is not exhausted.
        System.arraycopy(data1, i, merged, index, length1 - i);
        index += length1 - i;
        System.arraycopy(data2, j, merged, index, length2 - j);
        return merged;
    }

    /**
     * Splits an array into two halves (the second half gets the extra element for odd lengths).
     *
     * @param data the array to split; it is not modified
     * @param <T> the element type
     * @return copies of the two halves, with the same runtime component type as {@code data}
     */
    private static <T extends Comparable<T>> BiArray<T> separateArray(T[] data) {
        int middle = data.length / 2;
        return new BiArray<>(
                Arrays.copyOfRange(data, 0, middle), Arrays.copyOfRange(data, middle, data.length));
    }

    /**
     * Holder for a pair of arrays, used because creating a generic two-dimensional array is awkward.
     *
     * @param <T> the element type
     */
    private static class BiArray<T extends Comparable<T>> {
        private final T[] data1;
        private final T[] data2;

        public BiArray(T[] data1, T[] data2) {
            this.data1 = data1;
            this.data2 = data2;
        }

        public T[] getData1() {
            return data1;
        }

        public T[] getData2() {
            return data2;
        }
    }

    /**
     * Manual benchmark: sorts one million random integers (duplicates included), prints the elapsed
     * time and verifies the result against {@link Arrays#sort(Object[])}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int length = 1000000;
        Integer[] data = new Integer[length];
        final Random random = new Random();
        for (int i = 0; i < length; i++) {
            data[i] = random.nextInt(length * 1000);
        }
        final Integer[] expected = data.clone();
        Arrays.sort(expected);

        final MergeSortTask<Integer> mergeSort = new MergeSortTask<>(data);
        final long start = System.nanoTime();
        final Integer[] sorted = mergeSort.sort();
        final long end = System.nanoTime();
        System.out.println((end - start) / 1000000 + "ms");

        if (!Arrays.equals(sorted, expected)) {
            throw new IllegalStateException("merge sort produced an incorrectly ordered result");
        }
    }
}