java-forkjoin框架的使用


ForkJoin是Java7提供的原生多線程並行處理框架,其基本思想是將大任務分割成小任務,最后將小任務聚合起來得到結果。fork是分解的意思, join是收集的意思. 它非常類似於HADOOP提供的MapReduce框架,只是MapReduce的任務可以針對集群內的所有計算節點,可以充分利用集群的能力完成計算任務。ForkJoin更加類似於單機版的MapReduce。

 

在fork/join框架中,若某個子問題由於等待另一個子問題的完成而無法繼續執行。那么處理該子問題的線程會主動尋找其他尚未運行完成的子問題來執行。這種方式減少了線程的等待時間,提高了性能。子問題中應該避免使用synchronized關鍵詞或其他方式方式的同步。也不應該是一阻塞IO或過多的訪問共享變量。在理想情況下,每個子問題的實現中都應該只進行CPU相關的計算,並且只適用每個問題的內部對象。唯一的同步應該只發生在子問題和創建它的父問題之間。

 

Fork/Join使用兩個類完成以上兩件事情:

 
· ForkJoinTask: 我們要使用ForkJoin框架,必須首先創建一個ForkJoin任務。它提供在任務中執行fork()和join的操作機制,ForkJoinTask實現了Future接口,可以按照Future接口的方式來使用。在ForkJoinTask類中之重要的兩個方法fork和join。fork方法用以一部方式啟動任務的執行,join方法則等待任務完成並返回指向結果。在創建自己的任務是,最好不要直接繼承自ForkJoinTask類,而要繼承自ForkJoinTask類的子類RecurisiveTask或RecurisiveAction類
    1. RecursiveAction,用於沒有返回結果的任務
    2. RecursiveTask,用於有返回值的任務

 源碼推薦查詢 jdk8的

 

 

 · ForkJoinPool:task要通過ForkJoinPool來執行,分割的子任務也會添加到當前工作線程的雙端隊列中,進入隊列的頭部。當一個工作線程中沒有任務時,會從其他工作線程的隊列尾部獲取一個任務。

 2個構造方法

ForkJoinPool(int parallelism)  創建一個包含parallelism個並行線程的ForkJoinPool。

ForkJoinPool()  以Runtime.availableProcessors()方法的返回值作為parallelism參數來創建ForkJoinPool。

3種方式啟動

異步執行          execute(ForkJoinTask)         ForkJoinTask.fork
等待獲取結果      invoke(ForkJoinTask)          ForkJoinTask.invoke
執行,獲取Future    submit(ForkJoinTask)        ForkJoinTask.fork(ForkJoinTask are Futures)        

異常處理: 

ForkJoinTask在執行的時候可能會拋出異常,但是沒辦法在主線程里直接捕獲異常,所以ForkJoinTask提供了isCompletedAbnormally()方法來檢查任務是否已經拋出異常或已經被取消了,並且可以通過ForkJoinTask的getException方法獲取異常. 

getException方法返回Throwable對象,如果任務被取消了則返回CancellationException。如果任務沒有完成或者沒有拋出異常則返回null。

if(task.isCompletedAbnormally()) {
    System.out.println(task.getException());
}

然后, 代碼展示

import java.util.concurrent.ForkJoinPool
import java.util.concurrent.ForkJoinTask
import java.util.concurrent.RecursiveTask
/**
 * fork
 * 對一個大數組進行並行求和的RecursiveTask
 *
 * 大任務可以拆成小任務,小任務還可以繼續拆成更小的任務,最后把任務的結果匯總合並,得到最終結果,這種模型就是Fork/Join模型。
 Java7引入了Fork/Join框架,我們通過RecursiveTask這個類就可以方便地實現Fork/Join模式。
 * Created by wenbronk on 2017/7/13.
 */
class ForkJoinTest extends RecursiveTask<Long> {

    static final int THRESHOLD = 100
    long[] array
    int start
    int end

    ForkJoinTest(long[] array, int start, int end) {
        this.start = start
        this.end = end
        this.array = array
    }
    @Override
    protected Long compute() {

        if (end - start < THRESHOLD) {
            long sum = 0
            for (int i = start; i < end; i++) {
                sum += array[i]
            }
            try {
                Thread.sleep(100)
            } catch (Exception e) {
                e.printStackTrace()
            }
            println String.format('compute %d %d = %d', start, end, sum)
        }

        // 對於大任務, 分多線程執行
        int middle = (end + start) / 2
        println String.format('split %d %d => %d %d, %d %d', start, end, start, middle, middle, end)

        def subtask1 = new ForkJoinTest(this.array, start, middle);
        def subtask2 = new ForkJoinTest(this.array, middle, end);
        invokeAll(subtask1, subtask2)

        Long subresult1 = subtask1.join()
        Long subresult2 = subtask2.join()

        Long result = subresult1 + subresult2

        System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
        return result
    }

    public static void main(String[] args) throws Exception {
        // 創建隨機數組成的數組:
        long[] array = new long[400];
//        fillRandom(array);
        // fork/join task:
        ForkJoinPool fjp = new ForkJoinPool(4); // 最大並發數4
        ForkJoinTask<Long> task = new ForkJoinTest(array, 0, array.length);
        long startTime = System.currentTimeMillis();
        Long result = fjp.invoke(task);
        long endTime = System.currentTimeMillis();
        System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
    }
}

 java代碼的實現

package com.wenbronk.test;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * forkjoin的簡單易用
 * Created by wenbronk on 2017/7/26.
 */
public class CountTask extends RecursiveTask<Integer>{
    private volatile static int count = 0;
    private int start;
    private int end;

    public CountTask(int start, int end) {
        this.start = start;
        this.end = end;
    }

    public static final int threadhold = 2;

    @Override
    protected Integer compute() {
        int sum = 0;
        System.out.println("開啟了一條線程單獨干: " + count++);
        // 如果任務足夠小, 就直接執行
        boolean canCompute = (end - start) <= threadhold;
        if (canCompute) {
            for (int i = start; i <= end; i++) {
                sum += i;
            }
        }else {
            //任務大於閾值, 分裂為2個任務
            int middle = (start + end) / 2;
            CountTask countTask1 = new CountTask(start, middle);
            CountTask countTask2 = new CountTask(middle + 1, end);

            // 開啟線程
//            countTask1.fork();
//            countTask2.fork();
            invokeAll(countTask1, countTask2);

            Integer join1 = countTask1.join();
            Integer join2 = countTask2.join();

            // 結果合並
            sum = join1 + join2;
        }
        return sum;
    }


    // 測試
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();

        CountTask countTask = new CountTask(1, 100);
        ForkJoinTask<Integer> result = forkJoinPool.submit(countTask);
        System.out.println(result.get());
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM