简介

从JDK1.7开始,Java提供ForkJoin框架用于并行执行任务,它的思想就是将一个大任务分割成若干小任务,最终汇总每个小任务的结果得到这个大任务的结果。

整个流程需要三个类完成

1、ForkJoinPool

既然任务是被逐渐的细化的,那就需要把这些任务存在一个池子里面,这个池子就是ForkJoinPool。

它与其它的ExecutorService区别主要在于它使用“工作窃取“,那什么是工作窃取呢?

一个大任务会被划分成无数个小任务,这些任务被分配到不同的队列,这些队列有些干活干的块,有些干得慢。于是干得快的,一看自己没任务需要执行了,就去隔壁的队列里面拿去任务执行。

2、ForkJoinTask

ForkJoinTask就是ForkJoinPool里面的每一个任务。他主要有两个子类:RecursiveActionRecursiveTask。然后通过fork()方法去分配任务执行任务,通过join()方法汇总任务结果,

  • RecursiveAction 一个递归无结果的ForkJoinTask(没有返回值)

  • RecursiveTask 一个递归有结果的ForkJoinTask(有返回值)

需要注意的是:<mark>这两个子类都是抽象类,需要继承实现。</mark>

举例说明

我们举个例子:如果要计算一个超大数组的和,最简单的做法是用一个循环在一个线程内完成:

┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘

还有一种方法,可以把数组拆成两部分,分别计算,最后加起来就是最终结果,这样可以用两个线程并行执行:

┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘

如果拆成两部分还是很大,我们还可以继续拆,用4个线程并行执行:

┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘

这就是Fork/Join任务的原理:判断一个任务是否足够小,如果是,直接计算,否则,就分拆成几个小任务分别计算。这个过程可以反复“裂变”成一系列小任务。

编码实现

整个任务流程如下所示

  • 首先继承任务,覆写任务的执行方法
  • 通过判断阈值,判断该线程是否可以执行
  • 如果不能执行,则将任务继续递归分配,利用fork方法,并行执行
  • 如果是有返回值的,才需要调用join方法,汇集数据。

RecursiveTask

这是一个有返回值的返回值的子类

public class RecursiveTaskTest {

    private final static int MAX_THRESHOLD = 3;//设置一个任务处理最大的阈值


    public static void main(String[] args) {
        final ForkJoinPool joinPool = new ForkJoinPool();
        ForkJoinTask<Integer> future = joinPool.submit(new CalculatedRecursiveTask(0, 1000));
        try {
            Integer integer = future.get();
            System.out.println("执行结果:" + integer);
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }


    private static class CalculatedRecursiveTask extends RecursiveTask<Integer> {


        private final int start;//任务开始的上标
        private final int end;//任务开始的下标

        private CalculatedRecursiveTask(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            if (end - start <= MAX_THRESHOLD) {//如果自己能处理,就自己计算
                return IntStream.rangeClosed(start, end).sum();
            } else {//自己处理不了,拆分任务
                int middle = (end + start) / 2;
                CalculatedRecursiveTask leftTask = new CalculatedRecursiveTask(start, middle);
                CalculatedRecursiveTask rightTask = new CalculatedRecursiveTask(middle + 1, end);

                leftTask.fork();
                rightTask.fork();

                return leftTask.join() + rightTask.join();
            }
        }
    }
}

结果:

执行结果:500500

RecursiveAction

这是一个没有返回值的返回值的子类

public class ForkJoinRecursiveAction {

    private final static int MAX_THRESHOLD = 3;//设置一个任务处理最大的阈值

    private final static AtomicInteger SUM = new AtomicInteger();


    public static void main(String[] args) throws InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();

        forkJoinPool.submit(new CalculateRecursiveAction(0,1000));

        forkJoinPool.awaitTermination(10, TimeUnit.SECONDS);

        System.out.println("执行结果为:" + SUM);
    }

    private static class CalculateRecursiveAction extends RecursiveAction{

        private final int start;
        private final int end;

        private CalculateRecursiveAction(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            if ((end-start)<=MAX_THRESHOLD){
                SUM.addAndGet(IntStream.rangeClosed(start,end).sum());
            }else {
                int middle = (start+end)/2;
                CalculateRecursiveAction leftAction = new CalculateRecursiveAction(start,middle);
                CalculateRecursiveAction rightAction = new CalculateRecursiveAction(middle+1,end);
                leftAction.fork();
                rightAction.fork();
            }
        }
    }
}

结果:

执行结果:500500

支持Runnable和Callable

public void execute(Runnable task) {
    if (task == null)
        throw new NullPointerException();
    ForkJoinTask<?> job;
    if (task instanceof ForkJoinTask<?>) // avoid re-wrap
        job = (ForkJoinTask<?>) task;
    else
        job = new ForkJoinTask.RunnableExecuteAction(task);
    externalPush(job);
}



public <T> ForkJoinTask<T> submit(Callable<T> task) {
    ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
    externalPush(job);
    return job;
}

从源码上可以看到,即使不继承两个子类,也可以提交任务

总结

  • Fork/Join是一种基于“分治”的算法:通过分解任务,并行执行,最后合并结果得到最终结果。

  • ForkJoinPool线程池可以把一个大任务分拆成小任务并行执行,任务类必须继承自RecursiveTask或RecursiveAction。

  • 使用Fork/Join模式可以进行并行计算以提高效率。

  • Java标准库提供的java.util.Arrays.parallelSort(array)可以进行并行排序,它的原理就是内部通过Fork/Join对大数组分拆进行并行排序,在多核CPU上就可以大大提高排序的速度。