使用ForkJoin框架实现归并排序

Fork/Join

Fork/Join是一种采用“分治法”思想的框架,即将一个大问题拆分为无数个相同的小问题,然后各个击破再统一组装。jdk中的ForkJoinPool就是使用fork/join框架的多线程工具类(同其他线程池一样,扩展自抽象类AbstractExecutorService),我们只需要编写如何拆分任务的代码,它就能自动使用多线程去处理子任务,为我们屏蔽了线程的创建和使用细节,并且支持work−stealing(工作密取)。

工作密取(work−stealing)

在forkjoin中,任务拆分为子任务后,每个线程都有自己的子任务队列,当前线程执行完自己的任务队列后,会去查看其他线程的任务队列中是否有未执行的任务,若有则取出执行,确保每个线程都处于忙碌状态,最大限度提升性能。

ForkJoinPool的线程池大小是多少?如何修改?

public ForkJoinPool() {
    this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
         defaultForkJoinWorkerThreadFactory, null, false);
}

从ForkJoinPool的空参构造可以看到,默认使用的线程池大小为MAX_CAP和 Runtime.getRuntime().availableProcessors()取最小值,MAX_CAP的值为:

static final int MAX_CAP = 0x7fff;

转换为10进制就是2的15次方减一,而Runtime.getRuntime().availableProcessors()是一个native方法,返回为当前所在服务器的逻辑核心数。显然,ForkJoinPool默认的线程池大小就是服务器的逻辑核心数。

 

同时,可以在ForkJoinPool的构造函数中指定线程池的大小(只要指定的值大于0并且小于MAX_CAP值)

public ForkJoinPool(int parallelism) {
    this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}

 使用ForkJoinPool实现归并排序

归并排序算法也是采用了分治法的思想,为熟悉ForkJoin的用法,我们可以以实现归并排序为例,写一个多线程执行,并且支持工作密取的归并排序。

创建一个使用随机数填充的int数组,数组的默认长度为500万

public class ArraysUtils {

    private static final int DEFAULT_ARRAY_LENGTH = 5000000;

    private static Random random = new Random();

    public static int[] generateRandomArray(int arrayLength) {
        if (arrayLength <= 0) {
            arrayLength = DEFAULT_ARRAY_LENGTH;
        }
        int[] arr = new int[arrayLength];
        for (int i = 0; i < arrayLength; i++) {
            arr[i] = random.nextInt(arrayLength);
        }
        return arr;
    }

    public static int[] generateRandomArray() {
        return generateRandomArray(DEFAULT_ARRAY_LENGTH);
    }
}

在forkjoin中任务的抽象类有两个:RecursiveAction和RecursiveTask<T>,一个是不带返回值的任务,一个是带返回值的任务,显然我们需要拿到子任务的排序结果,所以使用RecursiveTask

public class SortTask extends RecursiveTask<int[]> {
    /**
     * 排序数组长度的阈值
     */
    private static final int THRESHOLD = 20;

    private static MergeSort mergeSort = new MergeSort();
    /**
     * 需要排序的数组
     */
    private int[] originArray;
    /**
     * 排序执行类
     */
    private SortActuator sortActuator;

    SortTask(int[] originArray, SortActuator sortActuator) {
        this.originArray = originArray;
        this.sortActuator = sortActuator;
    }

    @Override
    protected int[] compute() {
        if (originArray.length <= THRESHOLD) {
            // 数组长度<+阈值,则将数组进行排序
            return sortActuator.sort(originArray);
        } else {
            // 数组长度大于阈值,则继续将数组进行拆分
            int mid = originArray.length / 2;
            int[] leftArray = Arrays.copyOfRange(originArray, 0, mid);
            int[] rightArray = Arrays.copyOfRange(originArray, mid, originArray.length);
            SortTask leftTask = new SortTask(leftArray, sortActuator);
            SortTask rightTask = new SortTask(rightArray, sortActuator);
            invokeAll(leftTask, rightTask);
            // 将拆分后的数组排序并统计结果
            return mergeSort.merge(leftTask.join(), rightTask.join());
        }
    }
}

 Task中接收两个参数,一个是需要排序的数组,另一个是SortActuator对象,SortActuator是我们自己定义的具体进行排序的执行器接口,在任务拆分到合适的大小后,比如将500万大小的数组拆分为50万个长度为10的数组,可以使用SortActuator中的方法对单个数组进行排序。这里以最简单的冒泡排序为例。

public interface SortActuator {

    int[] sort(int[] originArray);
}
public class BubblingSortActuator implements SortActuator {

    @Override
    public int[] sort(int[] originArray) {
        for (int i = originArray.length - 1; i >= 0; i--) {
            for (int j = 0; j < i; j++) {
                if (originArray[j] > originArray[j + 1]) {
                    int temp = originArray[j + 1];
                    originArray[j + 1] = originArray[j];
                    originArray[j] = temp;
                }
            }
        }
        return originArray;
    }
}

在SortTask的compute方法中,如果数组的长度大于阈值,则拆分为两个数组,分别排序后合并结果,合并结果时调用的是MergeSort类的merge方法(归并排序中“归并”的关键)。

public class MergeSort {
    /**
     * 合并结果
     *
     * @param arrayLeft
     * @param arrayRight
     * @return
     */
    public int[] merge(int[] arrayLeft, int[] arrayRight) {
        int[] result = new int[arrayLeft.length + arrayRight.length];
        for (int index = 0, lIndex = 0, rIndex = 0; index < result.length; index++) {
            if (lIndex > arrayLeft.length - 1) {
                // 左边数组已经取完,直接将右边数组的数放进result数组中
                result[index] = arrayRight[rIndex++];
            } else if (rIndex > arrayRight.length - 1) {
                // 右边数组已经取完,直接将左边数组的数放进result数组中
                result[index] = arrayLeft[lIndex++];
            } else if (arrayLeft[lIndex] < arrayRight[rIndex]) {
                // 左边数组中index位置的值更小,则取左边数组中的数放入result数组中,并且将左数组指针+1
                result[index] = arrayLeft[lIndex++];
            } else {
                result[index] = arrayRight[rIndex++];
            }
        }
        return result;
    }
}

最后,测试一下结果,在控制台打印。

public class ForkJoinMergeTest {

    private static int sum = 0;

    public static void main(String[] args) {
        int[] originArray = ArraysUtils.generateRandomArray();
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        SortTask sortTask = new SortTask(originArray, new BubblingSortActuator());
        long startMillis = System.currentTimeMillis();
        forkJoinPool.execute(sortTask);
        int[] sortResult = sortTask.join();
        long expend = System.currentTimeMillis() - startMillis;
        Arrays.stream(sortResult).forEach(i -> {
            if (sum % 10 == 0) {
                System.out.println();
            }
            System.out.print(i + "<");
            sum++;
        });
        System.out.println();
        System.out.println("排序数组容量:" + originArray.length + ",共耗时:" + expend + "ms");
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值