04-23.eri-test Java 8 Parallel Streams with a ThreadPool

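A parallel stream runs in the common fork/join pool (ForkJoinPool.commonPool()), which is shared by all other parallel streams in the JVM. By default that pool has one thread fewer than the number of available processors, and the thread that starts the stream also helps execute it.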
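
To see this default for yourself, here is a small sketch that records the names of the threads that process a plain parallel stream. CommonPoolDemo and its method are hypothetical names used only for illustration. On current JDKs the set typically contains the calling thread plus workers named ForkJoinPool.commonPool-worker-N, but those worker names are a JDK detail rather than a documented API.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

// Hypothetical demo class: which threads execute a plain parallel stream?
class CommonPoolDemo {
    /** Returns the names of all threads that processed at least one stream element. */
    static Set<String> threadNamesOfPlainParallelStream() {
        Set<String> names = ConcurrentHashMap.newKeySet(); // thread-safe set
        IntStream.range(0, 100_000).parallel()
                .forEach(i -> names.add(Thread.currentThread().getName()));
        return names;
    }

    public static void main(String[] args) {
        System.out.println("Common pool parallelism: " + ForkJoinPool.commonPool().getParallelism());
        // Typically the calling thread ("main") plus some "ForkJoinPool.commonPool-worker-N" threads
        System.out.println("Threads used: " + threadNamesOfPlainParallelStream());
    }
}
```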

Sometimes we want to execute code in parallel on a separate, dedicated thread pool with a specific number of threads. Calling, for example, myCollection.parallelStream() gives us no convenient way to do that.

I wrote a small, handy utility (the ThreadExecutor class) that can be used for that purpose.

In the following example, I demonstrate simple usage of the ThreadExecutor utility: it fills a long array with calculated numbers, and each number is calculated on a thread of a dedicated ForkJoinPool (not the common pool).

The utility creates the thread pool itself. We control the number of threads in the pool (int parallelism), the name of the pool's threads (useful when investigating a thread dump) and, optionally, a timeout limit.
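
The core idea behind such a utility can be sketched in a few lines of plain JDK code: submit the function that starts the parallel stream to your own ForkJoinPool, then wait for its result. This is a minimal sketch under that assumption, not the author's ThreadExecutor; DedicatedPoolSketch and runInPool are hypothetical names, and the sketch leaves out thread naming and timeouts.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.Function;

// Hypothetical sketch of the idea behind ThreadExecutor.execute (no thread naming, no timeout).
class DedicatedPoolSketch {
    /** Applies `parallelStream` to `source` on a worker of a new pool with `parallelism` threads. */
    static <T, R> R runInPool(int parallelism, T source, Function<T, R> parallelStream) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            // The function runs on a worker of `pool`, so a parallel stream started inside it
            // forks its subtasks into `pool` rather than into ForkJoinPool.commonPool().
            return pool.submit(() -> parallelStream.apply(source)).join();
        } finally {
            pool.shutdown(); // lets the workers exit once the work is done
        }
    }

    public static void main(String[] args) {
        List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);
        int sumOfSquares = runInPool(2, numbers,
                list -> list.parallelStream().mapToInt(x -> x * x).sum());
        System.out.println(sumOfSquares); // 55
        // ForkJoinTask.getPool() returns the pool hosting the current thread
        boolean dedicated = runInPool(2, numbers,
                list -> ForkJoinTask.getPool() != ForkJoinPool.commonPool());
        System.out.println(dedicated); // true
    }
}
```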

I tested it with JUnit 5, which provides a nice way to time test methods (see TimingExtension).

All source code is available on GitHub at:
https://github.com/igalhaddad/thread-executor

ThreadExecutor utility class

import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.common.util.concurrent.UncheckedTimeoutException;

import java.time.Duration;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class ThreadExecutor {
    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, Function<T, R> parallelStream) {
        return execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Function<T, R> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        // The function runs on a worker of threadPool, so the parallel stream it starts
        // forks its subtasks into threadPool instead of the common pool
        Future<R> future = threadPool.submit(() -> parallelStream.apply(source));
        try {
            return timeout == 0 ? future.get() : future.get(timeout, unit);
        } catch (ExecutionException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (TimeoutException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (Throwable t) {
            future.cancel(true);
            threadPool.shutdownNow();
            if (t instanceof InterruptedException)
                Thread.currentThread().interrupt(); // restore the interrupt status
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        } finally {
            threadPool.shutdown();
        }
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, Consumer<T> parallelStream) {
        execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Consumer<T> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        try {
            // Use a CompletableFuture in both branches, so that an exception thrown by the
            // consumer reaches the caller even when a timeout is set
            CompletableFuture<Void> future = CompletableFuture.runAsync(() -> parallelStream.accept(source), threadPool);
            if (timeout == 0)
                future.get();
            else
                future.get(timeout, unit);
            threadPool.shutdown();
        } catch (TimeoutException e) {
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException("Timed out after: " + Duration.of(timeout, unit.toChronoUnit()), e);
        } catch (ExecutionException e) {
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (Throwable t) {
            threadPool.shutdownNow();
            if (t instanceof InterruptedException)
                Thread.currentThread().interrupt(); // restore the interrupt status
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }
}
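
The timeout branches of ThreadExecutor follow a common pattern: wait with Future.get(timeout, unit) and, if the deadline passes, stop the pool with shutdownNow(). The following is a minimal sketch of just that pattern using plain JDK classes instead of Guava; TimeoutSketch, runWithTimeout and the Optional-based result are hypothetical choices made for illustration, not part of the original utility. Cancelling the Future alone may not stop work that is already running, which is why the sketch also calls shutdownNow().

```java
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import java.util.stream.IntStream;

// Hypothetical sketch of the "wait with a deadline, then shut the pool down" pattern.
class TimeoutSketch {
    /** Runs `work` on a dedicated pool; returns Optional.empty() if it misses the deadline. */
    static <R> Optional<R> runWithTimeout(int parallelism, long timeoutMillis, Supplier<R> work) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        Callable<R> task = work::get;
        Future<R> future = pool.submit(task);
        try {
            return Optional.ofNullable(future.get(timeoutMillis, TimeUnit.MILLISECONDS));
        } catch (TimeoutException e) {
            future.cancel(true); // may not stop work that is already running...
            pool.shutdownNow();  // ...so also cancel queued tasks and stop the pool
            return Optional.empty();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Task failed", e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt status
            throw new IllegalStateException("Interrupted while waiting", e);
        } finally {
            pool.shutdown(); // no-op if shutdownNow() already ran
        }
    }

    /** Sleeps without a checked exception; returns early if the thread is interrupted. */
    static void sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void main(String[] args) {
        // Finishes well within the deadline: prints Optional[5050]
        System.out.println(runWithTimeout(2, 10_000, () -> IntStream.rangeClosed(1, 100).parallel().sum()));
        // Every element sleeps for 2 s, so the 50 ms deadline is missed: prints Optional.empty
        System.out.println(runWithTimeout(2, 50, () -> IntStream.range(0, 4).parallel()
                .map(i -> { sleepQuietly(2_000); return i; }).sum()));
    }
}
```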

NamedForkJoinWorkerThreadFactory class

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
    private final AtomicInteger counter = new AtomicInteger(0);
    private final String name;
    private final boolean daemon;

    public NamedForkJoinWorkerThreadFactory(String name, boolean daemon) {
        this.name = name;
        this.daemon = daemon;
    }

    public NamedForkJoinWorkerThreadFactory(String name) {
        this(name, false);
    }

    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        // Start from the JDK's default worker thread, then give it a numbered name, e.g. "FibonacciTask3"
        ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
        t.setName(name + counter.incrementAndGet());
        t.setDaemon(daemon);
        return t;
    }
}
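
A factory such as NamedForkJoinWorkerThreadFactory is passed to the four-argument ForkJoinPool constructor, as ThreadExecutor does. To keep the following sketch self-contained, it builds an equivalent factory as a lambda (ForkJoinWorkerThreadFactory has a single method, newThread); NamedPoolSketch and its methods are hypothetical. It runs a parallel stream inside the pool and collects the names of the threads that did the work, which should all carry the chosen prefix.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

// Hypothetical sketch: a ForkJoinPool whose workers are named prefix1, prefix2, ...
class NamedPoolSketch {
    static ForkJoinPool namedPool(int parallelism, String prefix) {
        AtomicInteger counter = new AtomicInteger();
        ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
            ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            t.setName(prefix + counter.incrementAndGet());
            return t;
        };
        // Same argument order as in ThreadExecutor: parallelism, factory, no handler, async mode
        return new ForkJoinPool(parallelism, factory, null, true);
    }

    /** Runs a parallel stream inside a named pool and returns the names of the threads it used. */
    static Set<String> threadNames(int parallelism, String prefix) {
        ForkJoinPool pool = namedPool(parallelism, prefix);
        try {
            return pool.submit(() -> {
                Set<String> names = ConcurrentHashMap.newKeySet();
                IntStream.range(0, 100_000).parallel()
                        .forEach(i -> names.add(Thread.currentThread().getName()));
                return names;
            }).join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Something like [FibonacciTask1, FibonacciTask2, FibonacciTask3]; order and count may vary
        System.out.println(threadNames(3, "FibonacciTask"));
    }
}
```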

ThreadExecutorTests JUnit class

import static org.junit.jupiter.api.Assertions.*;

import com.github.igalhaddad.threadexecutor.timing.TimingExtension;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@ExtendWith(TimingExtension.class)
@TestMethodOrder(OrderAnnotation.class)
@DisplayName("Test ThreadExecutor utility")
public class ThreadExecutorTests {
    private static final Logger logger = Logger.getLogger(ThreadExecutorTests.class.getName());
    private static final int SEQUENCE_LENGTH = 1000000;

    private static List<long[]> fibonacciSequences = new ArrayList<>();
    private long[] fibonacciSequence;

    @BeforeAll
    static void initAll() {
        logger.info(() -> "Number of available processors: " + Runtime.getRuntime().availableProcessors());
    }

    @BeforeEach
    void init() {
        this.fibonacciSequence = new long[SEQUENCE_LENGTH];
        fibonacciSequences.add(fibonacciSequence);
    }

    @AfterEach
    void tearDown() {
        int firstX = 10;
        logger.info(() -> "First " + firstX + " numbers: " + Arrays.stream(this.fibonacciSequence)
                .limit(firstX)
                .mapToObj(Long::toString)
                .collect(Collectors.joining(",", "[", ",...]")));
        int n = SEQUENCE_LENGTH - 1; // Last number
        assertFn(n);
        assertFn(n / 2);
        assertFn(n / 3);
        assertFn(n / 5);
        assertFn(n / 10);
        assertFn((n / 3) * 2);
        assertFn((n / 5) * 4);
    }

    private void assertFn(int n) {
        assertEquals(fibonacciSequence[n - 1] + fibonacciSequence[n - 2], fibonacciSequence[n]);
    }

    @AfterAll
    static void tearDownAll() {
        long[] fibonacciSequence = fibonacciSequences.iterator().next();
        for (int i = 1; i < fibonacciSequences.size(); i++) {
            assertArrayEquals(fibonacciSequence, fibonacciSequences.get(i));
        }
    }

    @Test
    @Order(1)
    @DisplayName("Calculate Fibonacci sequence sequentially")
    public void testSequential() {
        logger.info(() -> "Running sequentially. No parallelism");
        for (int i = 0; i < fibonacciSequence.length; i++) {
            fibonacciSequence[i] = Fibonacci.compute(i);
        }
    }

    @Test
    @Order(2)
    @DisplayName("Calculate Fibonacci sequence concurrently on all processors")
    public void testParallel1() {
        testParallel(Runtime.getRuntime().availableProcessors());
    }

    @Test
    @Order(3)
    @DisplayName("Calculate Fibonacci sequence concurrently on half of the processors")
    public void testParallel2() {
        testParallel(Math.max(1, Runtime.getRuntime().availableProcessors() / 2));
    }

    private void testParallel(int parallelism) {
        logger.info(() -> String.format("Running in parallel on %d processors", parallelism));
        ThreadExecutor.execute(parallelism, "FibonacciTask", fibonacciSequence,
                (long[] fibonacciSequence) -> Arrays.parallelSetAll(fibonacciSequence, Fibonacci::compute)
        );
    }

    static class Fibonacci {
        public static long compute(int n) {
            if (n <= 1)
                return n;
            long a = 0, b = 1;
            long sum = a + b; // for n == 2
            for (int i = 3; i <= n; i++) {
                a = sum; // using `a` for temporary storage
                sum += b;
                b = a;
            }
            return sum;
        }
    }
}

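Notice the testParallel(int parallelism) method. It uses the ThreadExecutor utility to execute a parallel stream on a separate, dedicated thread pool with the given number of threads, where each thread is named "FibonacciTask" followed by a serial number, e.g., "FibonacciTask3". This works because Arrays.parallelSetAll is called from a worker thread of that pool: ForkJoinTask.fork() schedules a task in the pool the current task is running in, and falls back to the common pool only outside any ForkJoinPool. The parallel stream implementation relies on fork(), although the Stream API documentation does not promise which pool a parallel stream uses.

The named threads come from the NamedForkJoinWorkerThreadFactory class.

For example, I paused the testParallel2() test method at a breakpoint in the Fibonacci.compute method and saw six threads, named FibonacciTask1 through FibonacciTask6. Here is one of them: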

"FibonacciTask3@2715" prio=5 tid=0x22 nid=NA runnable
  java.lang.Thread.State: RUNNABLE
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$Fibonacci.compute(ThreadExecutorTests.java:103)
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$$Lambda$366.1484420181.applyAsLong(Unknown Source:-1)
  at java.util.Arrays.lambda$parallelSetAll$2(Arrays.java:5408)
  at java.util.Arrays$$Lambda$367.864455139.accept(Unknown Source:-1)
  at java.util.stream.ForEachOps$ForEachOp$OfInt.accept(ForEachOps.java:204)
  at java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Streams.java:104)
  at java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:699)
  at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:484)
  at java.util.stream.ForEachOps$ForEachTask.compute(ForEachOps.java:290)
  at java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:746)
  at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:290)
  at java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1016)
  at java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1665)
  at java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1598)
  at java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:177)

The testParallel(int parallelism) method calls Arrays.parallelSetAll, which is in fact just a simple parallel stream, as the JDK source code shows:

    public static void parallelSetAll(long[] array, IntToLongFunction generator) {
        Objects.requireNonNull(generator);
        IntStream.range(0, array.length).parallel().forEach(i -> { array[i] = generator.applyAsLong(i); });
    }
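
To check that equivalence directly, this hypothetical sketch fills one array with Arrays.parallelSetAll and another with the hand-written IntStream pipeline from the excerpt, using a simple squaring generator as a stand-in for Fibonacci::compute. The two arrays should be identical.

```java
import java.util.Arrays;
import java.util.function.IntToLongFunction;
import java.util.stream.IntStream;

// Hypothetical sketch: Arrays.parallelSetAll versus the equivalent parallel IntStream.
class ParallelSetAllSketch {
    static long[] withParallelSetAll(int length, IntToLongFunction generator) {
        long[] array = new long[length];
        Arrays.parallelSetAll(array, generator);
        return array;
    }

    static long[] withParallelStream(int length, IntToLongFunction generator) {
        long[] array = new long[length];
        // Each index is written by exactly one task, so no synchronization is needed
        IntStream.range(0, array.length).parallel().forEach(i -> array[i] = generator.applyAsLong(i));
        return array;
    }

    public static void main(String[] args) {
        IntToLongFunction square = i -> (long) i * i;
        long[] a = withParallelSetAll(10, square);
        long[] b = withParallelStream(10, square);
        System.out.println(Arrays.toString(a)); // [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
        System.out.println(Arrays.equals(a, b)); // true
    }
}
```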

Now let's look at the timing of the test methods ⏱:

Test Results

As you can see in the output:

  1. testSequential() took 148622 ms (no parallelism).
  2. testParallel1() took 16995 ms (12 processors in parallel).
  3. testParallel2() took 31152 ms (6 processors in parallel).

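All three test methods performed the same task: calculating a Fibonacci sequence of 1,000,000 numbers. Since Fibonacci.compute(i) recomputes each number from scratch in about i loop iterations, filling the whole array takes roughly n²/2 ≈ 5 × 10¹¹ iterations of purely CPU-bound work, which explains the long sequential run. (Numbers beyond index 92 overflow long and wrap around, but every test computes the same wrapped values, so the assertions still hold.)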
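
To put the three timings in perspective, speedup is the sequential time divided by the parallel time, and efficiency is the speedup divided by the number of threads. The small sketch below (SpeedupSketch is a hypothetical name) only applies those two formulas to the numbers reported above; it adds no new measurements.

```java
import java.util.Locale;

// Applies speedup = T_sequential / T_parallel and efficiency = speedup / threads
// to the timings reported in this post (milliseconds).
class SpeedupSketch {
    static double speedup(long sequentialMillis, long parallelMillis) {
        return (double) sequentialMillis / parallelMillis;
    }

    static double efficiency(long sequentialMillis, long parallelMillis, int threads) {
        return speedup(sequentialMillis, parallelMillis) / threads;
    }

    public static void main(String[] args) {
        long sequential = 148622; // testSequential()
        // testParallel1(): 12 threads, 16995 ms -> speedup 8.75, efficiency 0.73
        System.out.printf(Locale.ROOT, "12 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, 16995), efficiency(sequential, 16995, 12));
        // testParallel2(): 6 threads, 31152 ms -> speedup 4.77, efficiency 0.80
        System.out.printf(Locale.ROOT, "6 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, 31152), efficiency(sequential, 31152, 6));
    }
}
```

In other words, 12 threads finished about 8.7 times faster than the sequential run, and 6 threads about 4.8 times faster.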
