Multithreading the Merge Sort
0. 来源
在阅读 Introduction to Algorithms 中多线程一章后,决定自己实现 ++归并排序的多线程化++。本以为它会很简单,而且并发后效果会比较好。然而,实现之路困难重重,效果也很差。
故把整个过程记录在此。
1. 过程
1.1 从“串行”的归并排序开始
比较简单直接上代码。
import java.util.Arrays;

/**
 * Sequential top-down merge sort (the baseline the multithreaded versions are compared with).
 */
public class MergeSortSingle implements MergeSort {

    /**
     * Returns a sorted copy of {@code array}; the argument itself is not modified.
     *
     * @param array the values to sort (may be empty)
     * @return a new array with the same values in ascending order
     */
    @Override
    public int[] mergeSortFunction(int[] array) {
        int[] myArray = Arrays.copyOf(array, array.length);
        mergeSort(myArray, 0, myArray.length - 1);
        return myArray;
    }

    /** Sorts {@code array[s..e]} (inclusive bounds) in place. */
    private void mergeSort(int[] array, int s, int e) {
        // `s >= e` rather than `s != e`: the empty range (0, -1) of a zero-length array must stop
        // here too, otherwise it keeps splitting into (1, -1), (1, 0), ... forever.
        if (s >= e) {
            return;
        }
        int mid = (s + e) >>> 1; // overflow-safe midpoint
        mergeSort(array, s, mid);
        mergeSort(array, mid + 1, e);
        merge(array, s, mid, e);
    }

    /** Merges the sorted runs {@code array[s..mid]} and {@code array[mid+1..e]} back into {@code array[s..e]}. */
    private void merge(int[] array, int s, int mid, int e) {
        int[] left = Arrays.copyOfRange(array, s, mid + 1);
        int[] right = Arrays.copyOfRange(array, mid + 1, e + 1);
        int j = 0;
        int k = 0;
        for (int i = s; i <= e; i++) {
            // Take from the left run while it has elements and the right run is exhausted or not
            // smaller; `<=` keeps equal elements in their original order (stable merge).
            if (k >= right.length || (j < left.length && left[j] <= right[k])) {
                array[i] = left[j++];
            } else {
                array[i] = right[k++];
            }
        }
    }
}
1.2 直接多线程化
可以看到,归并排序是一种典型的分治算法,它分为分解、解决、合并三个部分。想对归并排序进行多线程化,很容易想到从++分解之处++下手,也就是让 `mergeSort()` 的两处递归调用并行执行。但是,需要注意的是:++考虑并发的同时,必须考虑同步。++ 比如,这里的 `merge()`(也就是合并过程)必须等到前面两处分解都结束后才能执行。OK,现在有了要并发的目标,实现上最困难之处有以下两点:
- 如何实现同步,也就是等待两处 `mergeSort()` 调用结束后才执行 `merge()`;
- 如何管理线程:使用线程池,还是不做管理。
从下面的代码中可以看出,这里使用了 `ExecutorService` 提供的线程池来管理线程,并利用 `ReentrantLock` 和 `Condition` 实现同步。
在使用它们的时候遇到下面两个问题:
- 线程池的大小:若使用固定大小的线程池,往往会发生死锁,除非设置得足够大。因为在这种嵌套并行中,每一层分解都会占用一个线程来等待它的子任务,要到回溯时才能逐一释放线程;一旦所有线程都在等待,剩下的子任务就没有线程可以执行了(这里应该画个图)。
- 锁所在的位置:此锁的作用是等待对应的子任务完成。如果把锁作为类的属性,所有分解层级都会在同一个条件上等待,被唤醒的可能并不是与该子任务对应的那个线程。所以,把锁作为类属性会扩大它的作用范围,从而导致问题。
import java.util.Arrays; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; public class MergeSortPool implements MergeSort { private ExecutorService executorService; // 锁必须在方法内 // private ReentrantLock reentrantLock = new ReentrantLock(); // private Condition condition = reentrantLock.newCondition(); public MergeSortPool() { // this.executorService = Executors.newFixedThreadPool(threadNum); this.executorService = Executors.newCachedThreadPool(); } @Override public int[] mergeSortFunction(int[] array) { int[] myArray = Arrays.copyOf(array, array.length); mergeSort(myArray, 0, myArray.length - 1); this.executorService.shutdown(); return myArray; } // TODO 若线程池设置的太小,则会导致死锁。因为,会出现把所有的线程 private void mergeSort(int[] array, int s, int e) { ReentrantLock reentrantLock = new ReentrantLock(); Condition condition = reentrantLock.newCondition(); AtomicBoolean needWait = new AtomicBoolean(true); if (s != e) { int mid = (s + e) / 2; this.executorService.execute(() -> { try { mergeSort(array, s, mid); } finally { try { reentrantLock.lock(); needWait.set(false); condition.signal(); } finally { reentrantLock.unlock(); } } }); mergeSort(array, mid + 1, e); try { reentrantLock.lock(); if (needWait.get()) condition.await(); } catch (InterruptedException e1) { e1.printStackTrace(); } finally { reentrantLock.unlock(); } merge(array, s, mid, e); } } private void merge(int[] array, int s, int mid, int e) { int[] a1 = Arrays.copyOfRange(array, s, mid + 1); int[] a2 = Arrays.copyOfRange(array, mid + 1, e + 1); int j = 0; int k = 0; for (int i = s; i <= e; i++) { if (j >= a1.length) { array[i] = a2[k]; ++k; continue; } if (k >= a2.length) { array[i] = a1[j]; ++j; continue; } if (a1[j] < a2[k]) { array[i] = a1[j]; ++j; } else { array[i] = a2[k]; ++k; } } } }
1.3 复杂多线程化
在运行了直接多线程化的版本后,发现运行时间比串行版本长很多。我很惊讶:难道是因为它的并行度太低?于是,我开始尝试复杂多线程化。
说白了,复杂多线程化就是把归并排序中的合并过程也多线程化。很显然,这需要重新设计 `merge()` 方法。
串行实现
新的 `merge()`
方法也是基于分治策略,所以在它上面的并行也属于嵌套并行。详细思路见代码(++说实话,这里走了一些弯路++)。
// Sequential implementation of merge sort with a divide-and-conquer merge step.
import java.util.Arrays;

/**
 * Merge sort whose merge step is itself divide and conquer: instead of a linear two-finger merge,
 * the median of the longer run is placed directly at its final position and the elements below
 * and above it are merged recursively. This is the sequential version of that algorithm.
 */
public class MergeImprove implements MergeSort {

    /**
     * Merges the sorted runs {@code array[s1..e1]} and {@code array[s2..e2]} (inclusive bounds; an
     * empty run has {@code e == s - 1}) into {@code saveArray}, starting at index {@code start}.
     *
     * <p>Each call writes exactly one output slot (the median's) and the two recursive calls write
     * disjoint slot ranges, so every slot of {@code saveArray[start..start + len1 + len2 - 1]} is
     * written exactly once.
     *
     * @param s1 the start of run 1
     * @param e1 the end of run 1
     * @param s2 the start of run 2
     * @param e2 the end of run 2
     * @param array the array holding both sorted runs
     * @param start the first index of saveArray to write
     * @param saveArray the output array
     */
    private void merge(int s1, int e1, int s2, int e2, int[] array, int start, int[] saveArray) {
        if (e2 - s2 > e1 - s1) {
            // Keep run 1 the longer run so splitting at its median always shrinks the problem.
            merge(s2, e2, s1, e1, array, start, saveArray);
            return;
        }
        // Run 1 is the longer run, so if it is empty both runs are: nothing left to place.
        if (e1 < s1) {
            return;
        }
        int medianPos = s1 + (e1 - s1) / 2;
        int median = array[medianPos];
        // Elements of run 2 strictly smaller than the median precede it in the output, as do the
        // (medianPos - s1) elements of run 1 before it.
        int pos = getPos(median, s2, e2, array);
        int newMedianPos = start + (medianPos - s1) + (pos - s2);
        saveArray[newMedianPos] = median;
        merge(s1, medianPos - 1, s2, pos - 1, array, start, saveArray);
        merge(medianPos + 1, e1, pos, e2, array, newMedianPos + 1, saveArray);
    }

    /**
     * Binary search (lower bound): returns the first index {@code i} in {@code [start, end + 1]}
     * with {@code array[i] >= num}, assuming {@code array[start..end]} is sorted ascending.
     */
    private static int getPos(int num, int start, int end, int[] array) {
        while (end >= start) {
            int mid = (end - start) / 2 + start;
            if (array[mid] >= num) {
                end = mid - 1;
            } else {
                start = mid + 1;
            }
        }
        return start;
    }

    /**
     * Returns a new array with the values of {@code array} in ascending order; {@code array} is
     * only read, never modified.
     *
     * @param array the values to sort (may be empty)
     * @return the sorted values
     */
    @Override
    public int[] mergeSortFunction(int[] array) {
        int[] saveArray = new int[array.length];
        if (array.length == 0) {
            return saveArray;
        }
        // One scratch buffer for the whole recursion; it and saveArray swap roles at every level
        // (allocating a full-length buffer per recursion node costs O(n^2) in total).
        int[] scratch = new int[array.length];
        mergeSort(array, 0, array.length - 1, saveArray, scratch);
        return saveArray;
    }

    /**
     * Writes the values of {@code array[s..e]}, sorted, into {@code saveArray[s..e]}.
     *
     * <p>{@code scratch[s..e]} is used as working space and left with unspecified contents. The
     * halves are sorted into {@code scratch}, using {@code saveArray} as their own working space;
     * that is safe because this level writes {@code saveArray[s..e]} only after both halves are
     * done, and the two halves touch disjoint index ranges.
     */
    private void mergeSort(int[] array, int s, int e, int[] saveArray, int[] scratch) {
        if (s == e) {
            saveArray[s] = array[s];
            return;
        }
        int mid = (s + e) >>> 1; // overflow-safe midpoint
        mergeSort(array, s, mid, scratch, saveArray);
        mergeSort(array, mid + 1, e, scratch, saveArray);
        merge(s, mid, mid + 1, e, scratch, s, saveArray);
    }

    public static void main(String[] args) {
        int[] array = {2, 9, 1, 8, 2, 10, 18, 21, 12, 5, 7, 4, 3};
        MergeImprove mergeImprove = new MergeImprove();
        int[] saveArray = mergeImprove.mergeSortFunction(array);
        System.out.println(Arrays.toString(saveArray));
    }
}
并行实现
import java.util.Arrays; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; public class MergeImprovePool implements MergeSort { private ExecutorService executorService; public MergeImprovePool() { this.executorService = Executors.newCachedThreadPool(); } private void merge(int s1, int e1, int s2, int e2, int[] array, int start, int[] saveArray) { // 1, make a1 is bigger than a2 if (e2 - s2 > e1 - s1) { int temp = s1; s1 = s2; s2 = temp; temp = e1; e1 = e2; e2 = temp; } // 0, solve little problem if (e1 - s1 == -1) return; // 2, get median of the a1 int medianPos = (s1 + e1) / 2; int median = array[medianPos]; // 3, get pos and redistribute int pos = getPos(median, s2, e2, array); int newMedianPos = start + (medianPos - s1) + (pos - s2); saveArray[newMedianPos] = array[medianPos]; int finalS = s1; int finalS1 = s2; // note1 AtomicBoolean flag = new AtomicBoolean(false); this.executorService.execute(() -> { merge(finalS, medianPos - 1, finalS1, pos - 1, array, start, saveArray); // note1 flag.getAndSet(true); }); merge(medianPos + 1, e1, pos, e2, array, newMedianPos + 1, saveArray); // note1 while (!flag.get()) { try { Thread.sleep(10); } catch (InterruptedException e) { e.printStackTrace(); } } } /** * maybe should via class libraries * * @return position */ private static int getPos(int num, int start, int end, int[] array) { while (end >= start) { int mid = (end - start) / 2 + start; if (array[mid] >= num) { end = mid - 1; } else { start = mid + 1; } } return start; } @Override public int[] mergeSortFunction(int[] array) { int[] saveArray = new int[array.length]; mergeSort(array, 0, array.length - 1, saveArray); // note2 /* while (((ThreadPoolExecutor) this.executorService).getActiveCount() > 0) { try { Thread.sleep(100); } catch (InterruptedException e) { e.printStackTrace(); } } // */ 
this.executorService.shutdown(); return saveArray; } private void mergeSort(int[] array, int s, int e, int[] saveArray) { ReentrantLock reentrantLock = new ReentrantLock(); Condition condition = reentrantLock.newCondition(); AtomicBoolean needWait = new AtomicBoolean(true); if (s != e) { int mid = (s + e) / 2; int[] t = new int[array.length]; this.executorService.execute(() -> { try { mergeSort(array, s, mid, t); } finally { try { reentrantLock.lock(); needWait.set(false); condition.signal(); } finally { reentrantLock.unlock(); } } }); mergeSort(array, mid + 1, e, t); try { reentrantLock.lock(); if (needWait.get()) condition.await(); } catch (InterruptedException e1) { e1.printStackTrace(); } finally { reentrantLock.unlock(); } merge(s, mid, mid + 1, e, t, s, saveArray); } else saveArray[s] = array[s]; } public static void main(String[] args) { int[] array = {2, 9, 1, 8, 2, 10, 18, 21, 12, 5, 7, 4, 3}; MergeImprovePool mergeImprovePool = new MergeImprovePool(); int[] saveArray = mergeImprovePool.mergeSortFunction(array); System.out.println(Arrays.toString(saveArray)); } }
实现时,主要发现两个问题(分别在代码中标为 `note1` 和 `note2`):
- 为什么需要在 `merge()` 方法的最后添加同步?
- `note2` 处还需不需要同步?(不需要。真正明白了上面的问题,这个问题也就迎刃而解了。)
针对 `note1` 的问题,理解过程如下:
- 通过分析 `merge()` 方法,我们可以知道,对它的一次调用只会对 `saveArray` 数组的每个相应位置赋值一次。由此,我们可能会认为它在多线程下也是安全的;确实,若只对它进行一次调用,这个判断是对的。
- 所以,一开始并没有在 `merge()` 方法的最后添加同步(sync),也就是等待所有子任务完成后才退出。然而,测试时发现结果中有很多 0,并且只是有一定概率能成功。
- 一开始我以为是++线程池的问题++:线程管理出了问题,或者线程池关闭过早?经调试,发现所有的排序过程都执行了,这肯定不是线程池的锅。
- 经过一下午的调试,终于发现少了 `sync` 操作!在 `note1` 处添加后,立马跑通!
- 为什么呢?`mergeSort()` 会多次调用 `merge()`,而上一层 `merge()` 的输入正是下一层 `merge()` 的输出(数组 `t`)。如果 `merge()` 不等待自己派发出去的子任务完成就返回,调用它的 `mergeSort()` 会认为这段数据已经合并好,随即交给上一层去合并,而此时部分位置可能还没被写入(仍是 0)。加上同步后,每次 `merge()` 返回时,它负责的那段结果都已完整写入,回溯过程中上一层读到的就一定是完整的数据。
2. 比较
方法 | 数量 | 时间 |
---|---|---|
MergeSort | 10000 | 0.025 |
MergeSortThread | 10000 | 3.127 |
MergeSortPool | 10000 | 1.018 |
MergeImprove | 10000 | 0.127 |
MergeImprovePool | 10000 | 2.207 |
MergeSort | 100000 | 0.043 |
MergeSortThread | 100000 | 230.621 |
MergeSortPool | 100000 | 2.791 |
MergeImprove | 100000 | 4.321 |
MergeImprovePool | 100000 | OutOfMemoryError |
3. 总结
结果分析请参考(原文此处的参考链接缺失)。
任重而道远!!!