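Multithreading Merge Sort

0. Origin

After reading the multithreading chapter of Introduction to Algorithms, I decided to implement ++a multithreaded merge sort++ myself. I expected it to be simple and expected concurrency to pay off. In practice the implementation was full of obstacles, and the results were poor.

So I am recording the whole process here.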

1. Process

1.1 Starting from the "serial" merge sort
  • This one is simple, so here is the code directly.

    import java.util.Arrays;
    
    public class MergeSortSingle implements MergeSort {
    
        @Override
        public int[] mergeSortFunction(int[] array) {
            int[] myArray = Arrays.copyOf(array, array.length);
            mergeSort(myArray, 0, myArray.length - 1);
            return myArray;
        }
    
        private void mergeSort(int[] array, int s, int e) {
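            // s < e rather than s != e: an empty array gives e == -1,
            // and s != e would then recurse until the stack overflows
            if (s < e) {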
                int mid = (s + e) / 2;
                mergeSort(array, s, mid);
                mergeSort(array, mid + 1, e);
                merge(array, s, mid, e);
            }
        }
    
        private void merge(int[] array, int s, int mid, int e) {
            int[] a1 = Arrays.copyOfRange(array, s, mid + 1);
            int[] a2 = Arrays.copyOfRange(array, mid + 1, e + 1);
            int j = 0;
            int k = 0;
            for (int i = s; i <= e; i++) {
                if (j >= a1.length) {
                    array[i] = a2[k];
                    ++k;
                    continue;
                }
                if (k >= a2.length) {
                    array[i] = a1[j];
                    ++j;
                    continue;
                }
                if (a1[j] < a2[k]) {
                    array[i] = a1[j];
                    ++j;
                } else {
                    array[i] = a2[k];
                    ++k;
                }
            }
        }
    
    }
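
The MergeSort interface that every class in this post implements is not shown. The following is a minimal sketch, assuming it declares only the mergeSortFunction method that the classes override. SortCheck is a hypothetical helper (not part of the original code) that compares any implementation against Arrays.sort.

```java
import java.util.Arrays;

// Assumed shape of the interface that every sorter in this post implements.
interface MergeSort {
    int[] mergeSortFunction(int[] array);
}

// Hypothetical helper, not part of the original code.
final class SortCheck {
    // True when the sorter returns the same contents as Arrays.sort on a copy of the input.
    static boolean agreesWithArraysSort(MergeSort sorter, int[] input) {
        int[] expected = Arrays.copyOf(input, input.length);
        Arrays.sort(expected);
        int[] actual = sorter.mergeSortFunction(Arrays.copyOf(input, input.length));
        return Arrays.equals(expected, actual);
    }
}
```

Under that assumption, a call such as SortCheck.agreesWithArraysSort(new MergeSortSingle(), data) would be a quick correctness check before any timing.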
    
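1.2 Direct multithreading
  • Merge sort is a typical divide-and-conquer algorithm with three parts: divide, conquer, and combine. To multithread it, the obvious place to start is ++the divide step++: the two recursive calls to mergeSort() can run in parallel.

  • Note, however: ++whenever you introduce concurrency, you must also think about synchronization.++ Here, merge() (the combine step) may only run after both recursive calls have finished.

  • OK, now that there is a target for concurrency, the two hardest parts of the implementation are:

    • How to synchronize, i.e. run merge() only after both mergeSort() calls have finished
    • How to manage threads: use a thread pool, or leave them unmanaged

    As the code shows, a thread pool from ExecutorService manages the threads, and a ReentrantLock with a Condition provides the synchronization.

  • Using them raised two problems:

    • The size of the thread pool. A fixed-size pool tends to deadlock unless it is large enough, because with nested parallelism a thread is released only when the recursion unwinds back to it. (This deserves a diagram.)
    • Where the lock lives. The lock exists only so that a caller can wait. If the lock and condition are class fields, every call waits on the same condition, so a signal can wake a waiter that belongs to a different mergeSort() call, when only the waiter of the matching call should be woken. Making the lock a class field widens its scope and causes this problem.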

      import java.util.Arrays;
      import java.util.concurrent.ExecutorService;
      import java.util.concurrent.Executors;
      import java.util.concurrent.atomic.AtomicBoolean;
      import java.util.concurrent.locks.Condition;
      import java.util.concurrent.locks.ReentrantLock;
      
      public class MergeSortPool implements MergeSort {
      
          private ExecutorService executorService;
      
          // the lock must be created inside the method (one per call)
      //    private ReentrantLock reentrantLock = new ReentrantLock();
      //    private Condition condition = reentrantLock.newCondition();
      
          public MergeSortPool() {
      //        this.executorService = Executors.newFixedThreadPool(threadNum);
              this.executorService = Executors.newCachedThreadPool();
          }
      
          @Override
          public int[] mergeSortFunction(int[] array) {
              int[] myArray = Arrays.copyOf(array, array.length);
              mergeSort(myArray, 0, myArray.length - 1);
              this.executorService.shutdown();
              return myArray;
          }
      
          // TODO if the pool is too small this deadlocks, because all of its threads end up blocked waiting for subtasks
          private void mergeSort(int[] array, int s, int e) {
              ReentrantLock reentrantLock = new ReentrantLock();
              Condition condition = reentrantLock.newCondition();
              AtomicBoolean needWait = new AtomicBoolean(true);
              // s < e rather than s != e, so an empty array (e == -1) returns immediately
              if (s < e) {
                  int mid = (s + e) / 2;
                  this.executorService.execute(() -> {
                      try {
                          mergeSort(array, s, mid);
                      } finally {
                          try {
                              reentrantLock.lock();
                              needWait.set(false);
                              condition.signal();
                          } finally {
                              reentrantLock.unlock();
                          }
                      }
                  });
                  mergeSort(array, mid + 1, e);
                  reentrantLock.lock();
                  try {
                      // loop rather than if: await() can return spuriously
                      while (needWait.get())
                          condition.await();
                  } catch (InterruptedException e1) {
                      e1.printStackTrace();
                  } finally {
                      reentrantLock.unlock();
                  }
                  merge(array, s, mid, e);
              }
          }
      
          private void merge(int[] array, int s, int mid, int e) {
              int[] a1 = Arrays.copyOfRange(array, s, mid + 1);
              int[] a2 = Arrays.copyOfRange(array, mid + 1, e + 1);
              int j = 0;
              int k = 0;
              for (int i = s; i <= e; i++) {
                  if (j >= a1.length) {
                      array[i] = a2[k];
                      ++k;
                      continue;
                  }
                  if (k >= a2.length) {
                      array[i] = a1[j];
                      ++j;
                      continue;
                  }
                  if (a1[j] < a2[k]) {
                      array[i] = a1[j];
                      ++j;
                  } else {
                      array[i] = a2[k];
                      ++k;
                  }
              }
          }
      }
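
Both problems described above (pool size and waiting for the other half) can be sidestepped with the fork/join framework in java.util.concurrent. This is a swapped-in technique, not what the post's code does. The sketch assumes a hypothetical class MergeSortForkJoin and an arbitrary cutoff of 1024 elements below which no further tasks are created. ForkJoinPool is designed for this kind of nested parallelism: a worker waiting in invokeAll can run other queued tasks instead of sitting blocked.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical class, not from the original post.
final class MergeSortForkJoin {

    // Assumed cutoff: ranges this small are sorted without creating more tasks.
    private static final int THRESHOLD = 1024;

    static int[] sort(int[] array) {
        int[] copy = Arrays.copyOf(array, array.length);
        if (copy.length > 1) {
            ForkJoinPool.commonPool().invoke(new SortTask(copy, 0, copy.length - 1));
        }
        return copy;
    }

    private static final class SortTask extends RecursiveAction {
        private final int[] array;
        private final int s;
        private final int e;

        SortTask(int[] array, int s, int e) {
            this.array = array;
            this.s = s;
            this.e = e;
        }

        @Override
        protected void compute() {
            if (e - s < THRESHOLD) {
                sequentialSort(array, s, e);
                return;
            }
            int mid = s + (e - s) / 2;
            // invokeAll returns only after both halves are done, so no explicit lock is needed
            invokeAll(new SortTask(array, s, mid), new SortTask(array, mid + 1, e));
            merge(array, s, mid, e);
        }
    }

    private static void sequentialSort(int[] array, int s, int e) {
        if (s < e) {
            int mid = s + (e - s) / 2;
            sequentialSort(array, s, mid);
            sequentialSort(array, mid + 1, e);
            merge(array, s, mid, e);
        }
    }

    // Merges the sorted ranges array[s..mid] and array[mid+1..e] via two temporary copies.
    private static void merge(int[] array, int s, int mid, int e) {
        int[] a1 = Arrays.copyOfRange(array, s, mid + 1);
        int[] a2 = Arrays.copyOfRange(array, mid + 1, e + 1);
        int j = 0;
        int k = 0;
        for (int i = s; i <= e; i++) {
            if (k >= a2.length || (j < a1.length && a1[j] <= a2[k])) {
                array[i] = a1[j++];
            } else {
                array[i] = a2[k++];
            }
        }
    }
}
```

The cutoff is the main tuning knob here: a larger value means fewer, bigger tasks, and the value 1024 is only a placeholder.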
      
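1.3 Complex multithreading

After running the direct multithreaded version, I found that it takes much longer than the serial one. That was a surprise: could its degree of parallelism be too low? So I started trying complex multithreading.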
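
One factor that may contribute to the slowdown, stated here as an assumption rather than a measured result: the direct version calls execute() once at every split, so the number of submitted tasks grows with the array length while the tasks near the leaves do almost no work. The hypothetical helper below mirrors the recursion of MergeSortPool.mergeSort() and counts those submissions. For n elements the count comes to n - 1.

```java
// Hypothetical helper: mirrors the recursion of MergeSortPool.mergeSort() and counts execute() calls.
final class TaskCount {
    static long submittedTasks(int s, int e) {
        if (s == e) {
            return 0; // single element: nothing is submitted
        }
        int mid = (s + e) / 2;
        // one execute() for the left half, plus whatever both halves submit themselves
        return 1 + submittedTasks(s, mid) + submittedTasks(mid + 1, e);
    }
}
```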

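Put plainly, complex multithreading means parallelizing the merge step of merge sort as well. Obviously, merge() has to be redesigned.

  • Serial implementation

    The new merge() is also based on divide and conquer, so its parallelism is nested parallelism too. See the code for the detailed idea (++to be honest, I took some detours here++).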

    // This class is the "serial" implementation
    import java.util.Arrays;
    
    public class MergeImprove implements MergeSort {
    
        private static int count = 0;
    
        // TODO this mixes my own code with other people's code, because my thinking was not clear at first.
        // TODO One idea: each call puts one element into its final position, until the recursion ends.
    
        // TODO within a single call to merge(), each relevant position of saveArray is assigned exactly once.
        // TODO however, mergeSort() calls merge() many times, so without a sync on every call there is a thread-safety problem.
    
        /**
         * Merges a1 and a2, two already sorted runs stored in the same array.
         *
         * @param s1        the start of a1
         * @param e1        the end of a1 (inclusive)
         * @param s2        the start of a2
         * @param e2        the end of a2 (inclusive)
         * @param array     the array holding both runs
         * @param start     the first output index in saveArray
         * @param saveArray the array that receives the merged result
         */
        private void merge(int s1, int e1, int s2, int e2, int[] array, int start, int[] saveArray) {
            // 1, make sure a1 is the longer run (swap if needed)
            if (e2 - s2 > e1 - s1) {
                int temp = s1;
                s1 = s2;
                s2 = temp;
                temp = e1;
                e1 = e2;
                e2 = temp;
            }
            // 0, base case: the longer run is empty, so both are
            if (e1 - s1 == -1)
                return;
            // 2, take the median of a1
            int medianPos = s1 + (e1 - s1) / 2;
            int median = array[medianPos];
            // 3, find the median's position in a2 and place it in the output
            int pos = getPos(median, s2, e2, array);
    
            int newMedianPos = start + (medianPos - s1) + (pos - s2);
            System.out.println(count + " pos :: " + newMedianPos);
            saveArray[newMedianPos] = array[medianPos];
            merge(s1, medianPos - 1, s2, pos - 1, array, start, saveArray);
            merge(medianPos + 1, e1, pos, e2, array, newMedianPos + 1, saveArray);
        }
    
        private static int getPos(int num, int start, int end, int[] array) {
            while (end >= start) {
                int mid = (end - start) / 2 + start;
                if (array[mid] >= num) {
                    end = mid - 1;
                } else {
                    start = mid + 1;
                }
            }
            return start;
        }
    
        @Override
        public int[] mergeSortFunction(int[] array) {
            int[] saveArray = new int[array.length];
            // guard: mergeSort() assumes at least one element
            if (array.length > 0)
                mergeSort(array, 0, array.length - 1, saveArray);
            return saveArray;
        }
    
        private void mergeSort(int[] array, int s, int e, int[] saveArray) {
            if (s != e) {
                int mid = (s + e) / 2;
                int[] t = new int[array.length];
                mergeSort(array, s, mid, t);
                mergeSort(array, mid + 1, e, t);
                ++count;
                merge(s, mid, mid + 1, e, t, s, saveArray);
            } else
                saveArray[s] = array[s];
        }
    
        public static void main(String[] args) {
            int[] array = {2, 9, 1, 8, 2, 10, 18, 21, 12, 5, 7, 4, 3};
            MergeImprove mergeImprove = new MergeImprove();
            int[] saveArray = mergeImprove.mergeSortFunction(array);
            System.out.println(Arrays.toString(saveArray));
        }
    }
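
To make one step of the divide-and-conquer merge() above concrete, here is a small worked example. MergeStepDemo and medianTarget are hypothetical names introduced only for illustration. lowerBound uses the same logic as getPos(). The example computes where the median of the longer run lands in the merged output.

```java
// Hypothetical demo class; lowerBound has the same logic as getPos() above.
final class MergeStepDemo {
    static int lowerBound(int num, int start, int end, int[] array) {
        while (end >= start) {
            int mid = start + (end - start) / 2;
            if (array[mid] >= num) {
                end = mid - 1;
            } else {
                start = mid + 1;
            }
        }
        return start;
    }

    // Output index of the median of the run array[s1..e1] when merged with array[s2..e2],
    // with the merged output beginning at index start. Assumes the first run is the longer one.
    static int medianTarget(int s1, int e1, int s2, int e2, int[] array, int start) {
        int medianPos = s1 + (e1 - s1) / 2;
        int pos = lowerBound(array[medianPos], s2, e2, array);
        return start + (medianPos - s1) + (pos - s2);
    }

    public static void main(String[] args) {
        // runs: a1 = {1, 4, 7, 9} at indices 0..3, a2 = {2, 3, 8} at indices 4..6
        int[] array = {1, 4, 7, 9, 2, 3, 8};
        System.out.println(medianTarget(0, 3, 4, 6, array, 0)); // prints 3
    }
}
```

The median 4 has one element of a1 before it (1) and two elements of a2 before it (2 and 3), so it belongs at output index 3. The two recursive calls then handle {1} with {2, 3} on the left and {7, 9} with {8} on the right.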
    
  • Parallel implementation

    import java.util.Arrays;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.atomic.AtomicBoolean;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.ReentrantLock;
    
    public class MergeImprovePool implements MergeSort {
    
        private ExecutorService executorService;
    
        public MergeImprovePool() {
            this.executorService = Executors.newCachedThreadPool();
        }
    
        private void merge(int s1, int e1, int s2, int e2, int[] array, int start, int[] saveArray) {
            // 1, make sure a1 is the longer run (swap if needed)
            if (e2 - s2 > e1 - s1) {
                int temp = s1;
                s1 = s2;
                s2 = temp;
                temp = e1;
                e1 = e2;
                e2 = temp;
            }
            // 0, base case: the longer run is empty, so both are
            if (e1 - s1 == -1)
                return;
            // 2, take the median of a1
            int medianPos = (s1 + e1) / 2;
            int median = array[medianPos];
            // 3, find the median's position in a2 and place it in the output
            int pos = getPos(median, s2, e2, array);
            int newMedianPos = start + (medianPos - s1) + (pos - s2);
            saveArray[newMedianPos] = array[medianPos];
            int finalS = s1;
            int finalS1 = s2;
            // note1
            AtomicBoolean flag = new AtomicBoolean(false);
            this.executorService.execute(() -> {
                try {
                    merge(finalS, medianPos - 1, finalS1, pos - 1, array, start, saveArray);
                } finally {
                    // note1: set the flag even if merge() throws, so the waiter cannot hang
                    flag.set(true);
                }
            });
            merge(medianPos + 1, e1, pos, e2, array, newMedianPos + 1, saveArray);
            // note1
            while (!flag.get()) {
                try {
                    Thread.sleep(10);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    
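        /**
         * Lower-bound binary search over array[start..end];
         * maybe this should come from a class library instead.
         *
         * @return the first index in [start, end + 1] whose element is >= num
         */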
        private static int getPos(int num, int start, int end, int[] array) {
            while (end >= start) {
                int mid = (end - start) / 2 + start;
                if (array[mid] >= num) {
                    end = mid - 1;
                } else {
                    start = mid + 1;
                }
            }
            return start;
        }
    
        @Override
        public int[] mergeSortFunction(int[] array) {
            int[] saveArray = new int[array.length];
            // guard: mergeSort() assumes at least one element
            if (array.length > 0)
                mergeSort(array, 0, array.length - 1, saveArray);
            // note2
            /*
            while (((ThreadPoolExecutor) this.executorService).getActiveCount() > 0) {
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // */
            this.executorService.shutdown();
    
            return saveArray;
        }
    
        private void mergeSort(int[] array, int s, int e, int[] saveArray) {
            ReentrantLock reentrantLock = new ReentrantLock();
            Condition condition = reentrantLock.newCondition();
            AtomicBoolean needWait = new AtomicBoolean(true);
            if (s != e) {
                int mid = (s + e) / 2;
                int[] t = new int[array.length];
                this.executorService.execute(() -> {
                    try {
                        mergeSort(array, s, mid, t);
                    } finally {
                        try {
                            reentrantLock.lock();
                            needWait.set(false);
                            condition.signal();
                        } finally {
                            reentrantLock.unlock();
                        }
                    }
                });
                mergeSort(array, mid + 1, e, t);
                reentrantLock.lock();
                try {
                    // loop rather than if: await() can return spuriously
                    while (needWait.get())
                        condition.await();
                } catch (InterruptedException e1) {
                    e1.printStackTrace();
                } finally {
                    reentrantLock.unlock();
                }
                merge(s, mid, mid + 1, e, t, s, saveArray);
            } else
                saveArray[s] = array[s];
        }
    
        public static void main(String[] args) {
            int[] array = {2, 9, 1, 8, 2, 10, 18, 21, 12, 5, 7, 4, 3};
            MergeImprovePool mergeImprovePool = new MergeImprovePool();
            int[] saveArray = mergeImprovePool.mergeSortFunction(array);
            System.out.println(Arrays.toString(saveArray));
        }
    }
    

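    During implementation I ran into two main questions (marked note1 and note2 in the code):

    • Why does merge() need synchronization at its end?
    • Is the code at note2 still needed? (No. Once the first question is really understood, this one answers itself.)

    For note1, my understanding developed as follows:

    1. Analyzing merge() shows that a single call assigns each relevant position of saveArray exactly once. That suggests it is also safe under multiple threads; and indeed, if merge() were called only once, that judgment would be correct.
    2. So at first I did not add any sync at the end of merge(), that is, a wait for all spawned work to finish before returning. In testing, however, the result contained many 0s, and the sort succeeded only with some probability.
    3. At first I suspected ++the thread pool++: a thread-management problem, perhaps the pool being shut down too early? Debugging showed that every sorting step had in fact executed, so the thread pool was not to blame.
    4. After an afternoon of debugging I finally found that the sync operation was missing. Argh! Once it was added (at note1), everything ran correctly right away!
    5. Why? mergeSort() calls merge() multiple times, so if each call does not sync, there is a thread-safety problem. Without the wait, a merge() call returns while the sub-merges it spawned are still writing, so the caller's next merge() can read slots that have not been written yet (still 0). With the wait in place, each merge() on the way back up the recursion starts only after the merge() calls it depends on have completed.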
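
The wait at note1 polls a flag with Thread.sleep(10). The same rule (do not return until the spawned half has finished) can also be expressed with the Future returned by ExecutorService.submit. This is a sketch under stated assumptions, not the post's code: FutureMerge, mergeRuns, and the cutoff of 64 are hypothetical, and the pool is assumed to be unbounded (a cached pool), since a waiting caller still occupies a thread.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical class, not from the original post.
final class FutureMerge {
    // Assumed cutoff: below this run length the two halves are merged in the calling thread.
    private static final int CUTOFF = 64;

    // Merges sorted runs src[s1..e1] and src[s2..e2] into dst, starting at index start.
    static void merge(ExecutorService pool, int s1, int e1, int s2, int e2,
                      int[] src, int start, int[] dst) {
        if (e2 - s2 > e1 - s1) { // make the first run the longer one
            int t = s1; s1 = s2; s2 = t;
            t = e1; e1 = e2; e2 = t;
        }
        if (e1 < s1) {
            return; // both runs are empty
        }
        int medianPos = s1 + (e1 - s1) / 2;
        int pos = lowerBound(src[medianPos], s2, e2, src);
        int newMedianPos = start + (medianPos - s1) + (pos - s2);
        dst[newMedianPos] = src[medianPos];
        final int fs1 = s1;
        final int fs2 = s2;
        if (e1 - s1 < CUTOFF) {
            merge(pool, fs1, medianPos - 1, fs2, pos - 1, src, start, dst);
            merge(pool, medianPos + 1, e1, pos, e2, src, newMedianPos + 1, dst);
            return;
        }
        Future<?> left = pool.submit(
                () -> merge(pool, fs1, medianPos - 1, fs2, pos - 1, src, start, dst));
        merge(pool, medianPos + 1, e1, pos, e2, src, newMedianPos + 1, dst);
        try {
            left.get(); // blocks until the spawned half is done and makes its writes visible
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(ex);
        } catch (ExecutionException ex) {
            throw new IllegalStateException(ex.getCause());
        }
    }

    static int lowerBound(int num, int start, int end, int[] array) {
        while (end >= start) {
            int mid = start + (end - start) / 2;
            if (array[mid] >= num) {
                end = mid - 1;
            } else {
                start = mid + 1;
            }
        }
        return start;
    }

    // Convenience wrapper: merges src[0..mid] with src[mid+1..] into a new array.
    static int[] mergeRuns(int[] src, int mid) {
        ExecutorService pool = Executors.newCachedThreadPool();
        try {
            int[] dst = new int[src.length];
            merge(pool, 0, mid, mid + 1, src.length - 1, src, 0, dst);
            return dst;
        } finally {
            pool.shutdown();
        }
    }
}
```

Compared with the polling loop, get() returns as soon as the task completes and rethrows a failure of the spawned half instead of waiting forever.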

2. Comparison

Method             Count     Time
MergeSort          10000     0.025
MergeSortThread    10000     3.127
MergeSortPool      10000     1.018
MergeImprove       10000     0.127
MergeImprovePool   10000     2.207
MergeSort          100000    0.043
MergeSortThread    100000    230.621
MergeSortPool      100000    2.791
MergeImprove       100000    4.321
MergeImprovePool   100000    out of memory
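
The post does not show how these times were taken. The following is a rough sketch of one way to take such a measurement, assuming a single call timed with System.nanoTime. SortTimer and its methods are hypothetical and are not the harness behind the table.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.UnaryOperator;

// Hypothetical harness, not the one used for the table above.
final class SortTimer {
    // Wall-clock seconds for one call of sorter on a private copy of input.
    static double secondsFor(UnaryOperator<int[]> sorter, int[] input) {
        int[] copy = Arrays.copyOf(input, input.length);
        long begin = System.nanoTime();
        sorter.apply(copy);
        return (System.nanoTime() - begin) / 1_000_000_000.0;
    }

    // Reproducible random input of the given size.
    static int[] randomInts(int count, long seed) {
        Random random = new Random(seed);
        int[] data = new int[count];
        for (int i = 0; i < count; i++) {
            data[i] = random.nextInt();
        }
        return data;
    }

    public static void main(String[] args) {
        int[] data = randomInts(100000, 1L);
        double seconds = secondsFor(a -> { Arrays.sort(a); return a; }, data);
        System.out.println("Arrays.sort, 100000 elements: " + seconds + " s");
    }
}
```

A single un-warmed run on the JVM is noisy, so repeated runs (or a dedicated benchmark harness) would give steadier numbers than one call.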

3. Summary

For an analysis of the results, please refer to

There is still a long way to go!!!
