废话不多说,喜欢直接上代码,我写的有注释,不懂的可以直接问。
一。普通归并排序
import java.util.Random;
/**
* 功能描述 : 普通归并排序
*
* @author Ziyear 2020-05-21 19:38
*/
public class MergeSort {
private static int MAX = 100000;
private static int[] arr = new int[MAX];
static {
Random random = new Random();
for (int i = 1; i <= MAX; i++) {
arr[i - 1] = random.nextInt(MAX);
}
}
public static void main(String[] args) {
long start = System.currentTimeMillis();
sort(arr);
long end = System.currentTimeMillis();
//打印排序用时
System.out.println(end - start);
}
/**
* 功能描述 : 归并排序(每次讲数组分成大小两个部分,并不是排序完成,需要递归)
*
* @param arr 要排序的数组
* @param lo 最左侧下标位置
* @param mid 中间下标位置
* @param hi 最右侧下标位置
* @author Ziyear 2020-5-21 19:47
*/
public static void merge(int[] arr, int lo, int mid, int hi) {
// 第一步,复制数组
int i = lo;
int j = mid + 1;
int[] temp = new int[arr.length];
if (hi - lo >= 0) System.arraycopy(arr, lo, temp, lo, hi - lo);
// 第二步,移动指针进行排序
for (int k = lo; k < hi; k++) {
// 左边用尽
if (i > mid) {
arr[k] = temp[j++];
}
// 右边用尽
else if (j > hi) {
arr[k] = temp[i++];
}
// 左边大于右边
else if (temp[j] < temp[i]) {
arr[k] = temp[j++];
}
// 右边大于左边
else {
arr[k] = temp[i++];
}
}
}
public static void sort(int[] arr) {
sort(arr, 0, arr.length - 1);
}
public static void sort(int[] arr, int lo, int hi) {
if (hi <= lo) {
return;
}
int mid = lo + (hi - lo) / 2;
// 左半边排序
sort(arr, lo, mid);
// 右半边排序
sort(arr, mid + 1, hi);
// 归并
merge(arr, lo, mid, hi);
}
}
打印耗时:2493
二。ForkJoin框架实现的归并排序
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
/**
* 功能描述 : ForkJoinPool实现归并排序
*
* @author Ziyear 2020-05-21 19:57
*/
public class ForkJoinMergeSort {
private static int MAX = 100000;
private static int[] arr = new int[MAX];
static {
Random random = new Random();
for (int i = 1; i <= MAX; i++) {
arr[i - 1] = random.nextInt(MAX);
}
}
public static void main(String[] args) {
long start = System.currentTimeMillis();
ForkJoinPool pool = new ForkJoinPool();
MergeSortTask task = new MergeSortTask(arr);
ForkJoinTask<int[]> taskResult = pool.submit(task);
try {
taskResult.get();
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
long end = System.currentTimeMillis();
System.out.println(end - start);
}
static class MergeSortTask extends RecursiveTask<int[]> {
private int[] source;
public MergeSortTask(int[] source) {
this.source = source;
}
@Override
protected int[] compute() {
int length = source.length;
// 如果条件成立,说明任务中要进行排序的集合还不够小
if (length > 2) {
int mid = length / 2;
// 拆分成两个子任务
MergeSortTask task1 = new MergeSortTask(Arrays.copyOf(source, mid));
task1.fork();
MergeSortTask task2 = new MergeSortTask(Arrays.copyOfRange(source, mid, length));
task2.fork();
// 将两个有序的数组,合并成一个有序的数组
int[] result1 = task1.join();
int[] result2 = task2.join();
return joinInts(result1, result2);
}
// 集合中只有一个或者两个元素,可以进行这两个元素的比较排序了
else {
if (length == 1 || source[0] < source[1]) {
return source;
} else {
int[] temp = new int[length];
temp[0] = source[1];
temp[1] = source[0];
return temp;
}
}
}
/**
* 功能描述 : 将两个有序数组合并起来
*
* @param arr1 数组1
* @param arr2 数组2
* @return {@link int[]}
* @author Ziyear 2020-5-21 20:12
*/
private int[] joinInts(int[] arr1, int[] arr2) {
int left = 0;
int right = 0;
int[] mergeArr = new int[arr1.length + arr2.length];
if (mergeArr.length == 0) {
return null;
}
for (int i = 0; i < arr1.length + arr2.length; i++) {
if (arr1.length == left) {
mergeArr[i] = arr2[right];
right++;
continue;
} else if (arr2.length == right) {
mergeArr[i] = arr1[left];
left++;
continue;
}
if (arr1[left] <= arr2[right]) {
mergeArr[i] = arr1[left];
left++;
} else {
mergeArr[i] = arr2[right];
right++;
}
}
return mergeArr;
}
}
}
打印耗时:64