前提概述
Java 7開始引入了一種新的Fork/Join線程池,它可以執行一種特殊的任務:把一個大任務拆成多個小任務並行執行。
我們舉個例子:如果要計算一個超大數組的和,最簡單的做法是用一個循環在一個線程內完成:
算法原理介紹
相信大家此前或多或少有了解到ForkJoin框架,ForkJoin框架其實就是一個線程池ExecutorService的實現,通過工作竊取(work-stealing)算法,獲取其他線程中未完成的任務來執行。可以充分利用機器的多處理器優勢,利用空閑的線程去並行快速完成一個可拆分為小任務的大任務,類似於分治算法。
實現達成目標
-
ForkJoin的目標,就是利用所有可用的處理能力來提高程序的響應和性能。本文將介紹ForkJoin框架,依次介紹基礎特性、案例使用、源碼剖析和實現亮點。
-
java.util.concurrent.ForkJoinPool由Java大師Doug Lea主持編寫,它可以將一個大的任務拆分成多個子任務進行並行處理,最后將子任務結果合並成最后的計算結果,並進行輸出。
基本使用
入門例子,用Fork/Join框架使用示例,在這個示例中我們計算了1-5000累加后的值:
public class TestForkAndJoinPlus {
private static final Integer MAX = 400;
static class WorkTask extends RecursiveTask<Integer> {
// 子任務開始計算的值
private Integer startValue;
// 子任務結束計算的值
private Integer endValue;
public WorkTask(Integer startValue , Integer endValue) {
this.startValue = startValue;
this.endValue = endValue;
}
@Override
protected Integer compute() {
// 如果小於最小分片閾值,則說明要進行相關的數據操作
// 可以正式進行累加計算了
if(endValue - startValue < MAX) {
System.out.println("開始計算的部分:startValue = " + startValue + ";endValue = " + endValue);
Integer totalValue = 0;
for(int index = this.startValue ; index <= this.endValue ; index++) {
totalValue += index;
}
return totalValue;
}
// 否則再進行任務拆分,拆分成兩個任務
else {
// 因為采用二分法,拆分,所以進行1/2切分數據量
WorkTask subTask1 = new WorkTask(startValue, (startValue + endValue) / 2);
subTask1.fork();//進行拆分機制控制
WorkTask subTask2 = new WorkTask((startValue + endValue) / 2 + 1 , endValue);
subTask2.fork();
return subTask1.join() + subTask2.join();
}
}
}
public static void main(String[] args) {
// 這是Fork/Join框架的線程池
ForkJoinPool pool = new ForkJoinPool();
ForkJoinTask<Integer> taskFuture = pool.submit(new MyForkJoinTask(1,1001));
try {
Integer result = taskFuture.get();
System.out.println("result = " + result);
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace(System.out);
}
}
}
對此我封裝了一個框架集合,基於JDK1.8+中的Fork/Join框架實現,參考的Fork/Join框架主要源代碼也基於JDK1.8+。
WorkTaskCallable實現抽象模型層次操作轉換
@Accessors(chain = true)
public class WorkTaskCallable<T> extends RecursiveTask<T> {
/**
* 斷言操作控制
*/
@Getter
private Predicate<T> predicate;
/**
* 執行參數化分割條件
*/
@Getter
private T splitParam;
/**
* 操作拆分方法操作機制
*/
@Getter
private Function<Object,Object[]> splitFunction;
/**
* 操作合並方法操作機制
*/
@Getter
private BiFunction<Object,Object,T> mergeFunction;
/**
* 操作處理機制
*/
@Setter
@Getter
private Function<T,T> processHandler;
/**
* 構造器是否進行分割操作
* @param predicate 判斷是否進行下一步分割的條件關系
* @param splitParam 分割參數
* @param splitFunction 分割方法
* @param mergeFunction 合並數據操作
*/
public WorkTaskCallable(Predicate predicate,T splitParam,Function<Object,Object[]> splitFunction,BiFunction<Object,Object,T> mergeFunction,Function<T,T> processHandler){
this.predicate = predicate;
this.splitParam = splitParam;
this.splitFunction = splitFunction;
this.mergeFunction = mergeFunction;
this.processHandler = processHandler;
}
/**
* 實際執行調用操作機制
* @return
*/
@Override
protected T compute() {
if(predicate.test(splitParam)){
Object[] result = splitFunction.apply(splitParam);
WorkTaskCallable workTaskCallable1 = new WorkTaskCallable(predicate,result[0],splitFunction,mergeFunction,processHandler);
workTaskCallable1.fork();
WorkTaskCallable workTaskCallable2 = new WorkTaskCallable(predicate,result[1],splitFunction,mergeFunction,processHandler);
workTaskCallable2.fork();
return mergeFunction.apply(workTaskCallable1.join(),workTaskCallable2.join());
}else{
return processHandler.apply(splitParam);
}
}
}
ArrayListWorkTaskCallable實現List集合層次操作轉換
/**
* @project-name:wiz-shrding-framework
* @package-name:com.wiz.sharding.framework.boot.common.thread.forkjoin
* @author:LiBo/Alex
* @create-date:2021-09-09 17:26
* @copyright:libo-alex4java
* @email:liboware@gmail.com
* @description:
*/
public class ArrayListWorkTaskCallable extends WorkTaskCallable<List>{
static Predicate<List> predicateFunction = param->param.size() > 3;
static Function<List,List[]> splitFunction = (param)-> {
if(predicateFunction.test(param)){
return new List[]{param.subList(0,param.size()/ 2),param.subList(param.size()/2,param.size())};
}else{
return new List[]{param.subList(0,param.size()+1),Lists.newArrayList()};
}
};
static BiFunction<List,List,List> mergeFunction = (param1,param2)->{
List datalist = Lists.newArrayList();
datalist.addAll(param2);
datalist.addAll(param1);
return datalist;
};
/**
* 構造器是否進行分割操作
* @param predicate 判斷是否進行下一步分割的條件關系
* @param splitParam 分割參數
* @param splitFunction 分割方法
* @param mergeFunction 合並數據操作
*/
public ArrayListWorkTaskCallable(Predicate<List> predicate, List splitParam, Function splitFunction, BiFunction mergeFunction,
Function<List,List> processHandler) {
super(predicate, splitParam, splitFunction, mergeFunction,processHandler);
}
public ArrayListWorkTaskCallable(List splitParam, Function splitFunction, BiFunction mergeFunction,
Function<List,List> processHandler) {
super(predicateFunction, splitParam, splitFunction, mergeFunction,processHandler);
}
public ArrayListWorkTaskCallable(Predicate<List> predicate,List splitParam,Function<List,List> processHandler) {
this(predicate, splitParam, splitFunction, mergeFunction,processHandler);
}
public ArrayListWorkTaskCallable(List splitParam,Function<List,List> processHandler) {
this(predicateFunction, splitParam, splitFunction, mergeFunction,processHandler);
}
public static void main(String[] args){
List dataList = Lists.newArrayList(0,1,2,3,4,5,6,7,8,9);
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
ForkJoinTask<List> forkJoinResult = forkJoinPool.submit(new ArrayListWorkTaskCallable(dataList,param->Lists.newArrayList(param.size())));
try {
System.out.println(forkJoinResult.get());
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
ForkJoin代碼分析
ForkJoinPool構造函數
/**
* Creates a {@code ForkJoinPool} with parallelism equal to {@link
* java.lang.Runtime#availableProcessors}, using the {@linkplain
* #defaultForkJoinWorkerThreadFactory default thread factory},
* no UncaughtExceptionHandler, and non-async LIFO processing mode.
*
* @throws SecurityException if a security manager exists and
* the caller is not permitted to modify threads
* because it does not hold {@link
* java.lang.RuntimePermission}{@code ("modifyThread")}
*/
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
/**
* Creates a {@code ForkJoinPool} with the indicated parallelism
* level, the {@linkplain
* #defaultForkJoinWorkerThreadFactory default thread factory},
* no UncaughtExceptionHandler, and non-async LIFO processing mode.
*
* @param parallelism the parallelism level
* @throws IllegalArgumentException if parallelism less than or
* equal to zero, or greater than implementation limit
* @throws SecurityException if a security manager exists and
* the caller is not permitted to modify threads
* because it does not hold {@link
* java.lang.RuntimePermission}{@code ("modifyThread")}
*/
public ForkJoinPool(int parallelism) {
this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}
/**
* Creates a {@code ForkJoinPool} with the given parameters.
*
* @param parallelism the parallelism level. For default value,
* use {@link java.lang.Runtime#availableProcessors}.
* @param factory the factory for creating new threads. For default value,
* use {@link #defaultForkJoinWorkerThreadFactory}.
* @param handler the handler for internal worker threads that
* terminate due to unrecoverable errors encountered while executing
* tasks. For default value, use {@code null}.
* @param asyncMode if true,
* establishes local first-in-first-out scheduling mode for forked
* tasks that are never joined. This mode may be more appropriate
* than default locally stack-based mode in applications in which
* worker threads only process event-style asynchronous tasks.
* For default value, use {@code false}.
* @throws IllegalArgumentException if parallelism less than or
* equal to zero, or greater than implementation limit
* @throws NullPointerException if the factory is null
* @throws SecurityException if a security manager exists and
* the caller is not permitted to modify threads
* because it does not hold {@link
* java.lang.RuntimePermission}{@code ("modifyThread")}
*/
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism),
checkFactory(factory),
handler,
(asyncMode ? FIFO_QUEUE : LIFO_QUEUE),
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission();
}
/**
* Creates a {@code ForkJoinPool} with the given parameters, without
* any security checks or parameter validation. Invoked directly by
* makeCommonPool.
*/
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.mode = (short)mode;
this.parallelism = (short)parallelism;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
-
parallelism:可並行級別,Fork/Join框架將依據這個並行級別的設定,決定框架內並行執行的線程數量。並行的每一個任務都會有一個線程進行處理,但是千萬不要將這個屬性理解成Fork/Join框架中最多存在的線程數量。
-
factory:當Fork/Join框架創建一個新的線程時,同樣會用到線程創建工廠。只不過這個線程工廠不再需要實現ThreadFactory接口,而是需要實現ForkJoinWorkerThreadFactory接口。后者是一個函數式接口,只需要實現一個名叫newThread的方法。
在Fork/Join框架中有一個默認的ForkJoinWorkerThreadFactory接口實現:DefaultForkJoinWorkerThreadFactory。
-
handler:異常捕獲處理器。當執行的任務中出現異常,並從任務中被拋出時,就會被handler捕獲。
-
asyncMode:這個參數也非常重要,從字面意思來看是指的異步模式,它並不是說Fork/Join框架是采用同步模式還是采用異步模式工作。Fork/Join框架中為每一個獨立工作的線程准備了對應的待執行任務隊列,這個任務隊列是使用數組進行組合的雙向隊列。即是說存在於隊列中的待執行任務,即可以使用先進先出的工作模式,也可以使用后進先出的工作模式。
-
先進先出
-
后進先出
-
當asyncMode設置為true的時候,隊列采用先進先出方式工作;反之則是采用后進先出的方式工作,該值默認為false
- asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
-
需要注意點
-
ForkJoinPool 一個構造函數只帶有parallelism參數,既是可以設定Fork/Join框架的最大並行任務數量;另一個構造函數則不帶有任何參數,對於最大並行任務數量也只是一個默認值——當前操作系統可以使用的CPU內核數量(Runtime.getRuntime().availableProcessors())。實際上ForkJoinPool還有一個私有的、原生構造函數,之上提到的三個構造函數都是對這個私有的、原生構造函數的調用。
-
如果你對Fork/Join框架沒有特定的執行要求,可以直接使用不帶有任何參數的構造函數。也就是說推薦基於當前操作系統可以使用的CPU內核數作為Fork/Join框架內最大並行任務數量,這樣可以保證CPU在處理並行任務時,盡量少發生任務線程間的運行狀態切換(實際上單個CPU內核上的線程間狀態切換基本上無法避免,因為操作系統同時運行多個線程和多個進程)。
-
從上面的的類關系圖可以看出來,ForkJoin框架的核心是ForkJoinPool類,基於AbstractExecutorService擴展(@sun.misc.Contended注解)。
-
ForkJoinPool中維護了一個隊列數組WorkQueue[],每個WorkQueue維護一個ForkJoinTask數組和當前工作線程。ForkJoinPool實現了工作竊取(work-stealing)算法並執行ForkJoinTask。
ForkJoinPool類的屬性介紹
-
ADD_WORKER: 100000000000000000000000000000000000000000000000 -> 1000 0000 0000 0000,用來配合ctl在控制線程數量時使用
-
ctl: 控制ForkJoinPool創建線程數量,(ctl & ADD_WORKER) != 0L 時創建線程,也就是當ctl的第16位不為0時,可以繼續創建線程
-
defaultForkJoinWorkerThreadFactory: 默認線程工廠,默認實現是DefaultForkJoinWorkerThreadFactory
-
runState: 全局鎖控制,全局運行狀態
-
workQueues: 工作隊列數組WorkQueue[]
-
config: 記錄並行數量和ForkJoinPool的模式(異步或同步)
WorkQueue類
-
qlock: 並發控制,put任務時的鎖控制
-
array: 任務數組ForkJoinTask<?>[]
-
pool: ForkJoinPool,所有線程和WorkQueue共享,用於工作竊取、任務狀態和工作狀態同步
-
base: array數組中取任務的下標
-
top: array數組中放置任務的下標
-
owner: 所屬線程,ForkJoin框架中,只有一個WorkQueue是沒有owner的,其他的均有具體線程owner
ForkJoinTask是能夠在ForkJoinPool中執行的任務抽象類,父類是Future,具體實現類有很多,這里主要關注RecursiveAction和RecursiveTask。
- RecursiveAction是沒有返回結果的任務
- RecursiveTask是需要返回結果的任務。
ForkJoinTask類屬性的介紹
status: 任務的狀態,對其他工作線程和pool可見,運行正常則status為負數,異常情況為正數。
ForkJoinTask功能介紹
-
ForkJoinTask任務是一種能在Fork/Join框架中運行的特定任務,也只有這種類型的任務可以在Fork/Join框架中被拆分運行和合並運行。
-
ForkJoinWorkerThread線程是一種在Fork/Join框架中運行的特性線程,它除了具有普通線程的特性外,最主要的特點是每一個ForkJoinWorkerThread線程都具有一個獨立的任務等待隊列(work queue),這個任務隊列用於存儲在本線程中被拆分的若干子任務。
只需要實現其compute()方法,在compute()中做最小任務控制,任務分解(fork)和結果合並(join)。
ForkJoinPool中執行的默認線程是ForkJoinWorkerThread,由默認工廠產生,可以自己重寫要實現的工作線程。同時會將ForkJoinPool引用放在每個工作線程中,供工作竊取時使用。
ForkJoinWorkerThread類屬性介紹
- pool: ForkJoinPool,所有線程和WorkQueue共享,用於工作竊取、任務狀態和工作狀態同步。
- workQueue: 當前線程的任務隊列,與WorkQueue的owner呼應。
簡易執行圖
實際上Fork/Join框架的內部工作過程要比這張圖復雜得多,例如如何決定某一個recursive task是使用哪條線程進行運行;再例如如何決定當一個任務/子任務提交到Fork/Join框架內部后,是創建一個新的線程去運行還是讓它進行隊列等待。
邏輯模型圖(盜一張圖:)
()
fork方法和join方法
Fork/Join框架中提供的fork方法和join方法,可以說是該框架中提供的最重要的兩個方法,它們和parallelism“可並行任務數量”配合工作。
Fork方法介紹
- Fork就是一個不斷分枝的過程,在當前任務的基礎上長出n多個子任務,他將新創建的子任務放入當前線程的work queue隊列中,Fork/Join框架將根據當前正在並發執行ForkJoinTask任務的ForkJoinWorkerThread線程狀態,決定是讓這個任務在隊列中等待,還是創建一個新的ForkJoinWorkerThread線程運行它,又或者是喚起其它正在等待任務的ForkJoinWorkerThread線程運行它。
當一個ForkJoinTask任務調用fork()方法時,當前線程會把這個任務放入到queue數組的queueTop位置,然后執行以下兩句代碼:
if ((s -= queueBase) <= 2)
pool.signalWork();
else if (s == m)
growQueue();
當調用signalWork()方法。signalWork()方法做了兩件事:1、喚配當前線程;2、當沒有活動線程時或者線程數較少時,添加新的線程。
Join方法介紹
Join是一個不斷等待,獲取任務執行結果的過程。
private int doJoin() {
Thread t; ForkJoinWorkerThread w; int s; boolean completed;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
if ((s = status) < 0)
return s;
if ((w = (ForkJoinWorkerThread)t).unpushTask(this)) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
return setCompletion(NORMAL);
}
return w.joinTask(this);
}
else
return externalAwaitDone();
}
- 第4行,(s=status)<0表示這個任務被執行完,直接返回執行結果狀態,上層捕獲到狀態后,決定是要獲取結果還是進行錯誤處理;
- 第6行,從queue中取出這個任務來執行,如果執行完了,就設置狀態為NORMAL;
- 前面unpushTask()方法在隊列中沒有這個任務時會返回false,15行調用joinTask等待這個任務完成。
- 由於ForkJoinPool中有一個數組叫submissionQueue,通過submit方法調用而且非ForkJoinTask這種任務會被放到這個隊列中。這種任務有可能被非ForkJoinWorkerThread線程執行,第18行表示如果是這種任務,等待它執行完成。
下面來看joinTask方法
final int joinTask(ForkJoinTask<?> joinMe) {
ForkJoinTask<?> prevJoin = currentJoin;
currentJoin = joinMe;
for (int s, retries = MAX_HELP;;) {
if ((s = joinMe.status) < 0) {
currentJoin = prevJoin;
return s;
}
if (retries > 0) {
if (queueTop != queueBase) {
if (!localHelpJoinTask(joinMe))
retries = 0; // cannot help
}
else if (retries == MAX_HELP >>> 1) {
--retries; // check uncommon case
if (tryDeqAndExec(joinMe) >= 0)
Thread.yield(); // for politeness
}
else
retries = helpJoinTask(joinMe) ? MAX_HELP : retries - 1;
}
else {
retries = MAX_HELP; // restart if not done
pool.tryAwaitJoin(joinMe);
}
}
}
- (1)這里有個常量MAX_HELP=16,表示幫助join的次數。第11行,queueTop!=queueBase表示本地隊列中有任務,如果這個任務剛好在隊首,則嘗試自己執行;否則返回false。這時retries被設置為0,表示不能幫助,因為自已隊列不為空,自己並不空閑。在下一次循環就會進入第24行,等待這個任務執行完成。
- (2)第20行helpJoinTask()方法返回false時,retries-1,連續8次都沒有幫到忙,就會進入第14行,調用yield讓權等待。沒辦法人口太差,想做點好事都不行,只有停下來休息一下。
- (3)當執行到第20行,表示自己隊列為空,可以去幫助這個任務了,下面來看是怎么幫助的?
outer:for (ForkJoinWorkerThread thread = this;;) {
// Try to find v, the stealer of task, by first using hint
ForkJoinWorkerThread v = ws[thread.stealHint & m];
if (v == null || v.currentSteal != task) {
for (int j = 0; ;) { // search array
if ((v = ws[j]) != null && v.currentSteal == task) {
thread.stealHint = j;
break; // save hint for next time
}
if (++j > m)
break outer; // can't find stealer
}
}
// Try to help v, using specialized form of deqTask
for (;;) {
ForkJoinTask<?>[] q; int b, i;
if (joinMe.status < 0)
break outer;
if ((b = v.queueBase) == v.queueTop ||
(q = v.queue) == null ||
(i = (q.length-1) & b) < 0)
break; // empty
long u = (i << ASHIFT) + ABASE;
ForkJoinTask<?> t = q[i];
if (task.status < 0)
break outer; // stale
if (t != null && v.queueBase == b &&
UNSAFE.compareAndSwapObject(q, u, t, null)) {
v.queueBase = b + 1;
v.stealHint = poolIndex;
ForkJoinTask<?> ps = currentSteal;
currentSteal = t;
t.doExec();
currentSteal = ps;
helped = true;
}
}
// Try to descend to find v's stealer
ForkJoinTask<?> next = v.currentJoin;
if (--levels > 0 && task.status >= 0 &&
next != null && next != task) {
task = next;
thread = v;
}
}
- (1)通過查看stealHint這個字段的注釋可以知道,它表示最近一次誰來偷過我的queue中的任務。因此通過stealHint並不能找到當前任務被誰偷了?所以第4行v.currentSteal != task完全可能。這時還有一個辦法找到這個任務被誰偷了,看看currentSteal這個字段的注釋表示最近偷的哪個任務。這里掃描所有偷來的任務與當前任務比較,如果相等,就是這個線程偷的。如果這兩種方法都不能找到小偷,只能等待了。
- (2)當找到了小偷后,以其人之身還之其人之道,從小偷那里偷任務過來,相當於你和小偷共同執行你的任務,會加速你的任務完成。
- (3)小偷也是爺,如果小偷也在等待一個任務完成,權利反轉(小偷等待的這個任務做為當前任務,小偷扮演當事人角色把前面的流程走一遍),這是一個遞歸的過程。