Fork/Join 框架 详解

Fork/Join 框架 有时也称为 分解/合并框架,Fork/Join 框架采用分而治之将问题拆分成小问题。在一个任务中,先检查要解决问题的大小,如果大于设定,那就将问题拆分成可以通过框架来执行的小任务,如果问题的大小比设定的大小要小就直接在任务里解决这个问题,然后根据需要返回结果

Fork/Join 框架基于以下两种操作 :
分解(Fork) : 当需要讲一个任务拆分成更小的多个任务时,在框架中执行这些任务
合并(Join) : 当一个主任务等待其创建的多个子任务的完成执行

Fork/Join 框架执行任务的限制 :
1> 任务只能使用 fork() 和 join() 操作当前同步机制,如果使用其它同步机制,工作者线程就不能执行其他任务,当然这些任务是在同步操作里时。比如,如果在 Fork/Join 框架 中将一个任务休眠,正在执行这个任务的工作者线程在休眠期内不能执行另外一个任务
2> 任务不能执行 I/O操作,比如文件数据的读取与写入 
3> 任务不能抛出非运行时异常,必须在代码中处理掉这些异常

Fork/Join 框架两个核心组成类
1> ForkJoinPool : 该类实现 ExecutorService接口和工作窃取算法,它管理工作者线程,并提供任务的状态信息,以及任务的执行信息
2> ForkJoinTask : 是在 ForkJoinPool 中执行任务的基类

在实现 Fork/Join 框架任务,通常需要实现以下两个类之一的子类
1> RecursiveAction : 用于没有返回结果的场景,继承 ForkJoinTask 类
2> RecursiveTask : 用于任务有返回结果的场景,继承 ForkJoinTask 类

在 ForkJoinPool 中执行 ForkJoinTask 时,可以采用同步或者异步。采用同步方法执行时,发送任务给 Fork/Join 线程池的方法直接到任务执行完成后才会返回结果。而采用异步方法执行时,发送任务给执行器的方法将立即返回结果,但任务仍能继续执行


ForkJoinPool 类
该类看作成一个特殊的 Executor执行器类型,该类实现 ExecutorService接口和工作窃取算法

public ForkJoinPool() : 使用无参构造器创建对象,将使用默认配置,创建的线程数等于计算机CPU数目的线程池,创建好 ForkJoinPool对象 之后线程也就准备好了,在线程池中等待任务的到达,然后开始执行

public void execute(ForkJoinTask<?> task) : 安排提供异步执行任务,只要一设置就开始执行这个任务

public void execute(Runnable task) : 使用 Runnable 对象执行任务时,就不采用窃取算法执行任务

public <T> T invoke(ForkJoinTask<T> task) : execute方法 时异步调用,而 invoke方法 则是同步调用直到传递进来的任务执行结束后才会返回

public int getActiveThreadCount() : 返回一个预估正在窃取或执行任务的线程数量,这个方法也许估计高于活跃线程数

public long getStealCount() : 从另外一个线程工作队列中返回一个估计全部窃取任务数量。当线程池不沉寂时,返回的被低估实际窃取总数的值。这个值也许对监控有用而且协调Fork/Join 框架程序。通常,窃取数量应该足够高来保持线程繁忙,但足够低以避免线程的开销和争用

public int getParallelism() : 返回此池的目标并行级别

public long getQueuedTaskCount() : 返回工作线程当前在队列中执行的任务总数(但不包括提交到尚未执行的池的任务)。这个值只是一个近似,通过遍历池中的所有线程获得。该方法可用于优化任务粒度

public void shutdown() : 可能会初始化一个有序开关,在此之前提交的任务将被执行,但不会有新任务被接受。如果这是 commonPool(),调用对执行状态没有影响,如果已经关闭,则没有额外的影响。在此方法过程中同时提交的任务可能或不可能被拒绝

public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException : 在 shutdown() 方法执行后阻塞线程直到所有任务完成执行或者当前线程被阻塞,又或是到达指定时间。当申请公共池时,直到程序停止 commonPool() 才会终止,这个方法相当于 awaitQuiescence 方法但总是返回 false


ForkJoinTask 抽象类
public final boolean isDone() : 如果任务完成,返回 true

public final boolean isCompletedNormally() : 如果完成任务没有抛出异常且没有取消,返回

public final V get() throws InterruptedException, ExecutionException : 如果有必要等待计算完成,然后获取结果,这个方法返回 RecursiveTask 的 compute() 计算结果

public final V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException : 如果有必要等待到在最大给定的时间内完成计算,然后如果可以则获取结果。如果超时计算还未完成,那就返回 null 值

public static void invokeAll(ForkJoinTask<?>... tasks)
public static void invokeAll(ForkJoinTask<?> t1, ForkJoinTask<?> t2) : invokeAll方法执行一个主任务所创建的多个子任务,这是一个同步调用,这个任务将等待子任务完成然后继续执行也可能结束。当一个主任务等待它的子任务时,执行这个主任务的工作者线程接收另一个等待执行的任务并开始执行
 分叉指定的任务,当 isDone方法 保存每个任务或遇到一个 (未检查的)异常时返回,在这种情况下,异常被重新抛出。如果超过一个任务遇到异常,此时这个方法抛出这些异常中任意一个。如果任何任务遇到异常,其他任务可能被取消。可是,每个任务的执行状态在异常返回时无法保证。可以使用 getException() 和相关方法获得每个任务的状态,检查它们是否被取消、正常完成或异常,或未处理

public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks) : 在特定集合中分开所有任务,当 isDone方法 保存每个任务或遇到一个 (未检查的)异常时返回,在这种情况下,异常被重新抛出。如果超过一个任务遇到异常,此时这个方法抛出这些异常中任意一个。如果任何任务遇到异常,其他任务可能被取消。可是,每个任务的执行状态在异常返回时无法保证。可以使用 getException() 和相关方法获得每个任务的状态,检查它们是否被取消、正常完成或异常,或未处理

public final ForkJoinTask<V> fork() : 在当前任务的池中,如果可获取,则在当前任务中执行该任务,或者使用 commonPool()方法如果不能获取 。虽然它不是强制执行的,但如果它已经完成并被重新初始化,那么它将不止一次地使用一个任务。随后对该任务状态的修改或其操作的任何数据,在执行它之前的任何线程中都不一定能始终观察到,除非在调用 join方法 或相关方法之前,或者调用 isDone() 返回 true
说白了就是若有空闲工作者线程或者创建新线程来执行任务,执行任务不会阻塞当前线程会直接返回 ForkJoinTask 来存放执行结果,此时住线程可运行其它任务

public final V join() : 当 isDone() 方法完成时返回计算结果,否则会一直等待计算完成。这个方法不同于 get() 方法 因为 在RuntimeException 非正常完成结果 或 没有ExecutionException 的 Error,此时 调用线程的中断不会因为抛出 InterruptedException 而导致方法突然返回
join() 方法不能被中断,若中断调用 join()方法的线程,方法将会抛出 InterruptedException

public final Throwable getException() : 返回在计算过程中抛出的异常,或者如果任务取消返回 CancellationException,或者如果 没有异常或者任务没有完成返回 null

public final boolean isCompletedNormally() : 如果任务完成没有抛出异常且没有取消,返回 true

public final boolean isCompletedAbnormally() : 如果任务抛出异常或取消,返回 true

public boolean cancel(boolean mayInterruptIfRunning) : 尝试取消执行的这个任务。如果任务已经完成或者某些原因不能取消尝试将会失败。如果取消任务成功而且当取消任务时这个任务还未开始,这个执行的任务将会停止。在这个方法返回成功后,除非调用 reinitialize() 随后调用 isCancelled() isDone() 且返回 true 而且调用 join 相关方法结果会是 CancellationException。这个方法可能在子类中被覆盖,但如果是这样的话,仍然必须确保这些属性保持不变。特别是,cancel方法本身不能抛出异常。这个方法设计的目的就是被其他任务调用。对于计算方法,为终止当前任务,你可以返回或者抛出 unchecked exception 或者 调用 completeExceptionally(Throwable) 方法
简单说,取消一个没有执行的任务,如果已经开始执行则无法取消,参数 mayInterruptIfRunning 为 true 表示这个任务即是在运行也将被取消

Fork/Join 框架的局限在于 ForkJoinPool线程中的任务不允许被取消


RecursiveAction 将任务分解
public class Product {
    private String name ;
    private double price ;
    public String getName() {
        return name ;
    }
    public void setName(String name) {
        this . name = name;
    }
    public double getPrice() {
        return price ;
    }
    public void setPrice( double price) {
        this . price = price;
    }
}

public class Task extends RecursiveAction {
    private static final long serialVersionUID = - 3009073490385728376L ;
    private List<Product> products ;
    private int first ;
    private int last ;
    private double increment ;
    public Task(List<Product> products, int first, int last, double increment) {
        this . products = products;
        this . first = first;
        this . last = last;
        this . increment = increment;
    }
    @Override
    protected void compute() {
        if ( last - first < 10 ) {
            updatePrices();
        } else {
            int middle = ( last + first ) / 2 ;
            System. out .printf( "Task: Pending tasks:%s \n " , getQueuedTaskCount ());
            Task t1 = new Task( products , first , middle + 1 , increment );
            Task t2 = new Task( products , middle + 1 , last , increment );
            invokeAll (t1, t2);
        }
    }
    private void updatePrices() {
        for ( int i = first ; i < last ; i++) {
            Product product = products .get(i);
            product.setPrice(product.getPrice() * (i + increment ));
        }
    }
    public static List<Product> generate( int size) {
        List<Product> ret = new ArrayList<>();
        for ( int i = 0 ; i < size; i++) {
            Product product = new Product();
            product.setName( "Product " + i);
            product.setPrice( 10 );
            ret.add(product);
        }
        return ret;
    }
    public static void main(String[] args) {
        List<Product> products = generate( 10000 );
        Task task = new Task(products, 0 , products.size(), 0.2 );
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        forkJoinPool.execute(task);
        try {
            do {
                System. out .printf( "Main: Thread Count: %d \n " , forkJoinPool.getActiveThreadCount());
                System. out .printf( "Main: Thread Steal: %d \n " , forkJoinPool.getStealCount());
                System. out .printf( "Main: Thread Parallelism: %d \n " , forkJoinPool.getParallelism());
                TimeUnit. MILLISECONDS .sleep( 5 );
            } while (!task.isDone());
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        forkJoinPool.shutdown();
        if (task.isCompletedNormally()) {
            System. out .printf( "Main: The process has completed normally. \n " );
        }
        for ( int i = 0 ; i < products.size(); i++) {
            Product product = products.get(i);
            if (product.getPrice() != 12 ) {
                System. out .printf( "Produt %s: %f \n " , product.getName(), product.getPrice());
            }
        }
        System. out .printf( "Main: End of the program. \n " );
    }
}

合并任务结果
public class Document {
    private String words [] = { "the" , "hello" , "goodbye" , "packt" , "java" , "thread" , "pool" ,   "random" , "class" , "main" };
    public String[][] generateDocument( int numLines, int numWords, String word) {
        int counter = 0 ;
        String document[][] = new String[numLines][numWords];
        Random random = new Random();
        for ( int i = 0 ; i < numLines; i++) {
            for ( int j = 0 ; j < numWords; j++) {
                int index = random.nextInt( words . length );
                document[i][j] = words [index];
                if (document[i][j].equals(word)) {
                    counter++;
                }
            }
        }
        System. out .println( "DocumentMock: The word appears " + counter + " times in the document" );
        return document;
    }
}

public class LineTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 8771633587730430578L ;
    private String line [];
    private int start , end ;
    private String word ;
    public LineTask(String[] line, int start, int end, String word) {
        this . line = line;
        this . start = start;
        this . end = end;
        this . word = word;
    }
    @Override
    protected Integer compute() {
        Integer result = null ;
        if ( end - start < 100 ) {
            result = count( line , start , end , word );
        } else {
            int mid = ( start + end ) / 2 ;
            LineTask task1 = new LineTask( line , start , mid, word );
            LineTask task2 = new LineTask( line , mid, end , word );
            invokeAll (task1, task2);
            try {
                result = groupResults(task1.get(), task2.get());
            } catch (InterruptedException | ExecutionException e) {
                e.printStackTrace();
            }
        }
        return result;
    }
    private Integer count(String[] line, int start, int end, String word) {
        int counter = 0 ;
        for ( int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        try {
            Thread. sleep ( 10 );
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        return counter;
    }
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
}

public class DocumentTask extends RecursiveTask<Integer> {
    private String document [][];
    private int start , end ;
    private String word ;
    public DocumentTask(String[][] document, int start, int end, String word) {
        this . document = document;
        this . start = start;
        this . end = end;
        this . word = word;
    }
    @Override
    protected Integer compute() {
        int result = 0 ;
        if ( end - start < 10 ) {
            result = processLines( document , start , end , word );
        } else {
            int mid = ( start + end ) / 2 ;
            DocumentTask task1 = new DocumentTask( document , start , mid, word );
            DocumentTask task2 = new DocumentTask( document , mid, end , word );
            invokeAll (task1, task2);
            try {
                result = groupResults(task1.get(), task2.get());
            } catch (InterruptedException | ExecutionException e) {
                e.printStackTrace();
            }
        }
        return result;
    }
    private Integer processLines(String[][] document, int start, int end, String word) {
        List<LineTask> tasks = new ArrayList<>();
        for ( int i = start; i < end; i++) {
            LineTask lineTask = new LineTask(document[i], 0 , document[i]. length , word);
            tasks.add(lineTask);
        }
        invokeAll (tasks);
        int result = 0 ;
        try {
            for (LineTask lineTask : tasks) {
                result = result + lineTask.get();
            }
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace();
        }
        return result;
    }
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
    public static void main(String[] args) {
        Document mock = new Document();
        String[][] document = mock.generateDocument( 100 , 1000 , "the" );
        DocumentTask task = new DocumentTask(document, 0 , 100 , "the" );
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        do {
            System. out .printf( "Main: Thread Parallelism: %d \n " , pool.getParallelism());
            System. out .printf( "Main: Thread Active Threads: %d \n " , pool.getActiveThreadCount());
            System. out .printf( "Main: Thread Task Count: %d \n " , pool.getQueuedTaskCount());
            System. out .printf( "Main: Thread Steal Count: %d \n " , pool.getStealCount());
            try {
                TimeUnit. MILLISECONDS .sleep( 5 );
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        pool.shutdown();
        try {
            pool.awaitTermination( 1 , TimeUnit. DAYS );
            System. out .printf( "Main: The word appears %d in the document" , task.get());
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace();
        }
    }
}

异步运行任务
public class FolderProcessor extends RecursiveTask<List<String>> {
    private static final long serialVersionUID = 8145504519186061283L ;
    private String path ;
    private String extension ;
    public FolderProcessor(String path, String extension) {
        this . path = path;
        this . extension = extension;
    }
    @Override
    protected List<String> compute() {
        List<String> filePathList = new ArrayList<>();
        List<FolderProcessor> tasks = new ArrayList<>();
        File file = new File( path );
        File content[] = file.listFiles();
        if (content != null ) {
            for ( int i = 0 ; i < content. length ; i++) {
                if (content[i].isDirectory()) {
                    FolderProcessor task = new FolderProcessor(content[i].getAbsolutePath(), extension );
                    task.fork();
                    tasks.add(task);
                } else {
                    if (checkFile(content[i].getName())) {
                        filePathList.add(content[i].getAbsolutePath());
                    }
                }
            }
            if (tasks.size() > 50 ) {
                System. out .printf( "%s: %d tasks ran. \n " , file.getAbsolutePath(), tasks.size());
            }
            addResultsFromTasks(filePathList, tasks);
        }
        return filePathList;
    }
    private void addResultsFromTasks(List<String> filePathList, List<FolderProcessor> tasks) {
//        tasks.forEach((item) -> list.addAll(item.join())); 使用 Lambda 一句话完成任务
        for (FolderProcessor folderProcessor : tasks) {
            filePathList.addAll(folderProcessor.join());
        }
    }
    private boolean checkFile(String name) {
        return name.endsWith( extension );
    }
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        FolderProcessor system = new FolderProcessor( "/Users/mew/Desktop/AllmyFile" , "log" );
        FolderProcessor apps = new FolderProcessor( "/Users/mew/Desktop/leanote-master" , "log" );
        FolderProcessor documents = new FolderProcessor( "/Users/mew/Desktop/Company_Workspace" , "log" );
        pool.execute(system);
        pool.execute(apps);
        pool.execute(documents);
        do {
            System. out .printf( "****************************** \n " );
            System. out .printf( "Main: Parallelism: %d \n " , pool.getParallelism());
            System. out .printf( "Main: Active Threads: %d \n " , pool.getActiveThreadCount());
            System. out .printf( "Main: Task Count: %d \n " , pool.getQueuedTaskCount());
            System. out .printf( "Main: Steal Count: %d \n " , pool.getStealCount());
            try {
                TimeUnit. SECONDS .sleep( 1 );
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while ((!system.isDone()) || (!apps.isDone()) || (!documents.isDone()));
        pool.shutdown();
        List<String> results = system.join();
        System. out .printf( "System: %d files found. \n " , results.size());
        results = apps.join();
        System. out .printf( "Apps: %d files found. \n " , results.size());
        results = documents.join();
        System. out .printf( "Documents: %d files found. \n " , results.size());
    }
}

取消任务
public class TaskManager {
    private List<ForkJoinTask<Integer>> tasks ;
    public TaskManager() {
        tasks = new ArrayList<>();
    }
    public void addTask(ForkJoinTask<Integer> task) {
        tasks .add(task);
    }
    public void cancleTasks(ForkJoinTask<Integer> cancelTask) {
        for (ForkJoinTask<Integer> task : tasks ) {
            if (task != cancelTask) {
                task.cancel( true );
                ((SearchNumberTask) task).writeCancelMessage();
            }
        }
    }
}

public class SearchNumberTask extends RecursiveTask<Integer> {
    private int [] numbers ;
    private int start , end ;
    private int number ;
    private TaskManager manager ;
    private final static int NOT_FOUND = - 1 ;
    public SearchNumberTask( int [] numbers, int start, int end, int number, TaskManager manager) {
        this . numbers = numbers;
        this . start = start;
        this . end = end;
        this . number = number;
        this . manager = manager;
    }
    @Override
    protected Integer compute() {
        System. out .println( "Task: " + start + ":" + end );
        int ret;
        if ( end - start > 10 ) {
            ret = launchTasks();
        } else {
            ret = lookForNumber();
        }
        return ret;
    }
    private int lookForNumber() {
        for ( int i = start ; i < end ; i++) {
            if ( numbers [i] == number ) {
                System. out .printf( "Task: Number %d found in position %d \n " , number , i);
                manager .cancleTasks( this );
                return i;
            }
            try {
                TimeUnit. SECONDS .sleep( 1 );
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        return NOT_FOUND ;
    }
    private int launchTasks() {
        int mid = ( start + end ) / 2 ;
        SearchNumberTask task1 = new SearchNumberTask( numbers , start , mid, number , manager );
        SearchNumberTask task2 = new SearchNumberTask( numbers , mid, end , number , manager );
        manager .addTask(task1);
        manager .addTask(task2);
        task1.fork();
        task2.fork();
        int returnValue = task1.join();
        if (returnValue != - 1 ) {
            return returnValue;
        }
        returnValue = task2.join();
        return returnValue;
    }
    public void writeCancelMessage() {
        System. out .printf( "Task: Cancelled task from %d to %d" , start , end );
    }
    public static int [] generateArray( int size) {
        int [] array = new int [size];
        Random random = new Random();
        for ( int i = 0 ; i < size; i++) {
            array[i] = random.nextInt( 10 );
        }
        return array;
    }
    public static void main(String[] args) {
        int [] array = generateArray ( 1000 );
        TaskManager manager = new TaskManager();
        ForkJoinPool pool = new ForkJoinPool();

        SearchNumberTask task = new SearchNumberTask(array, 0 , 1000 , 5 , manager);
        pool.execute(task);
        pool.shutdown();
        try {
            pool.awaitTermination( 1 , TimeUnit. DAYS );
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值