如何用多线程实现归并排序

等我有时间了,一定要把《算法导论》啃完,这本书的印刷质量实在太好了,滑稽。

之前听吴恩达老大说过Python里面的Numpy包的矩阵运算就是多线程的,所以能做到的情况下尽量用矩阵运算代替循环,这样能大大加快运算的速度。

为了提高速度,如果不涉及外部资源读取的话,要提高运行速度就要做到并行计算,依赖于处理器的数量;如果需要等待耗时的外部资源读取,就可以通过并发边读边运算。

算法导论有一章节提到了并行循环,多线程矩阵乘法和多线程归并排序,方法都是讲一个大的计算过程分成几个独立的小部分,各个部分让单独的线程去计算。

排序里面讲问题分解的典型的就有快排和归并,接下来看一下怎么写多线程的。

单线程的排序算法可见另外一篇文章:https://blog.csdn.net/whut2010hj/article/details/80786831

多线程归并排序

直接点的思考方式,归并排序先要把一个数据分成两个,然后这两个分别归并排序,拍完了把两个归并到一起,典型的递归。

那么我们直接点,先把数组分割好,然后开两个线程,一个线程给一个,等着两个线程都搞定了,在把两个结果合并起来。或者你觉得两个线程每个要处理的还是太长了,那就在这两个线程里面再把拿到的数组分割了,各自再开两个。尝试一下

先看下单线程的版本,做下测试

import java.util.Random;

public class Main {
    public static void main(String[] args) {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        mergeSort(data);
        printArr(data);
    }

    //递归
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合并
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //打印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }


}


/**
 * 产生随机数据
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }


}

可以看到算法是能正常运行的

按上面思路的多线程版本呢?用 两个线程试验了下

只修改了main函数,加入了一个verify用作验证排序是不是OK的,不能人眼看吧

import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class Main {
    public static void main (String[] args) throws InterruptedException {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        // mergeSort(data);
        //在这里修改
        int center = data.length/2;

        int[] tmp = new int[data.length];
        CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能够使一个线程在等待另            
                                                    //外一些线程完成各自工作之后,再继续执行
        new Thread(new Runnable(){
        
            @Override
            public void run() {
                mergeSort(data,tmp,0,center);
                latch.countDown();
            }
        }).start();

        new Thread(new Runnable(){
        
            @Override
            public void run() {
                mergeSort(data,tmp,center+1,data.length-1);
                latch.countDown();
            }
        }).start();

        latch.await();

        merge(data, tmp, 0, center+1, data.length-1);

        printArr(data);
        System.out.println();
        verify(data);
    }

    //递归
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合并
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //打印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }
    
    public static void verify(int[] nums) {
        for(int i=0;i<nums.length-1;i++){
            if(nums[i]>nums[i+1]){
                System.out.println("排序失败");
                return;
            } 

        }
        System.out.println("排序成功");
    }
    


}


/**
 * 产生随机数据
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }

}

结果是OK的

上面是按自己的构思开启的线程。

 

其实Java本身提供了更好的解决方案,就是Fork/Join框架, 贴一下这个框架的介绍

Doug Lea 大神写的Fork/Join框架的论文

并发编程网翻译

使用Fork/Join 我们需要知道两个类:

  • ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:
    • RecursiveAction:用于没有返回结果的任务。
    • RecursiveTask :用于有返回结果的任务。
  • ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

下面看下如何用这个框架实现多线程归并排序

import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class Main {
    public static void main (String[] args) throws InterruptedException {
        int length = 1000;
        int[] data = (new Data(length)).getData();
        printArr(data);
        System.out.println();
        // mergeSort(data);
        //在这里修改
        // int center = data.length/2;

        int[] tmp = new int[data.length];
        // CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能够使一个线程在等待另外一些线程完成各自工作之后,再继续执行
        // new Thread(new Runnable(){
        
        //     @Override
        //     public void run() {
        //         mergeSort(data,tmp,0,center);
        //         latch.countDown();
        //     }
        // }).start();

        // new Thread(new Runnable(){
        
        //     @Override
        //     public void run() {
        //         mergeSort(data,tmp,center+1,data.length-1);
        //         latch.countDown();
        //     }
        // }).start();

        // latch.await();

        // merge(data, tmp, 0, center+1, data.length-1);

        //Fork/Join 从这里开始
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Main.mergeTask task = new Main.mergeTask(data, tmp, 0, data.length-1);//创建任务
        forkJoinPool.execute(task);//执行任务
        forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);//阻塞当前线程直到pool中的任务都完成了

        printArr(data);
        System.out.println();
        verify(data);

    }

    //递归
    private static void mergeSort(int[] nums,int[] tmp,int left,int right){
        if(left<right){
            int center = (left+right)/2;
            mergeSort(nums,tmp,left,center);
            mergeSort(nums,tmp,center+1,right);
            merge(nums,tmp,left,center+1,right);
        }
    }

    //合并
    private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
        int leftEnd = rightPos-1;
        int tmpPos = leftPos;
        int numElements = rightEnd - leftPos + 1;
    
        while(leftPos<=leftEnd&&rightPos<=rightEnd){
            if(nums[leftPos]<nums[rightPos])
                tmp[tmpPos++]=nums[leftPos++];
            else 
                tmp[tmpPos++]=nums[rightPos++];
        }
        while(leftPos<=leftEnd)
            tmp[tmpPos++]=nums[leftPos++];
        
        while(rightPos<=rightEnd)
            tmp[tmpPos++]=nums[rightPos++];
    
        for(int i = 0;i<numElements;i++,rightEnd--)
            nums[rightEnd]=tmp[rightEnd];
    }
    public static void mergeSort(int[] nums){
        int[] tmp = new int[nums.length];
        mergeSort(nums,tmp,0,nums.length-1);
    }
    
    //打印
    public static void printArr(int[] arr) {
        for(int i : arr){
            System.out.print(i+" ");
        }
    }
    
    public static void verify(int[] nums) {
        for(int i=0;i<nums.length-1;i++){
            if(nums[i]>nums[i+1]){
                System.out.println("排序失败");
                return;
            } 

        }
        System.out.println("排序成功");
    }
    

    static class mergeTask extends RecursiveAction {
        private static final int THRESHOLD = 2;//设置任务大小阈值
        private int start;
        private int end;
        private int[] data;
        private int[] tmp;
    
        public mergeTask(int[] data, int[] tmp, int start, int end){
            this.data = data;
            this.tmp = tmp;
            this.start = start;
            this.end = end;
        }
    
        @Override
        protected void compute(){
            if((end - start)<=THRESHOLD){
                mergeSort(data,tmp,start,end);
            }else{
                int center = (start + end)/2;
                Main.mergeTask leftTask = new Main.mergeTask(data, tmp, start, center);
                Main.mergeTask rightTask = new Main.mergeTask(data, tmp, center+1, end);

                leftTask.fork();
                rightTask.fork();

                leftTask.join();
                rightTask.join();

                merge(data, tmp, start, center+1, end);

            }
        }
    }

}


/**
 * 产生随机数据
 */
class Data{
    int length;
    int[] data;

    public Data(int length){
        this.length = length;
        data = new int[length];
    }

    public int[] getData(){

        Random random = new Random(System.currentTimeMillis());
        for(int i=0;i<length;i++){
            data[i]=random.nextInt(2*length);
        }
        return data;
    }

}

结果也是OK的

以上都没有涉及到锁,虽然操作的是共享的数组,但是被读写的区域是被隔离开的。

 

也是在算法导论上瞟到多线程算法这么一章,顺藤摸瓜才知道有Fork/Join 这个东西,要学的东西真的多。

 

搞完这个我又联想到之前看过的一道算法题:

在大量的数据中,寻找最大的k个数,或者是出现次数最多的k个数据,比如说这个数据有10个G,放在一个大文件中,电脑内存4G。

解题思路就是先把这个文件分块,为了确保相同的数据在一个块中,通过计算Hash值来分块,相同Hash 放到一个块中。比如每分100个块,这样平均一个块就在100M左右,对每个块分别载入内存找最大的前K个数或者出现最多的前K个数据,最后比较这100*K个数据来得到结果。

怎么用多线程求解?

 

 

 

  • 15
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值