fork/join作為一個並發框架在jdk7的時候就加入到了我們的java並發包java.util.concurrent中,並且在java 8 的lambda並行流中充當着底層框架的角色。這樣一個優秀的框架設計,我自己想了解一下它的底層代碼是如何實現的,所以我嘗試的去閱讀了JDK相關的源碼。下面我打算分享一下閱讀完之后的心得~。
1、fork/join的設計思路
了解一個框架的第一件事,就是先了解別人的設計思路!

fork/join大體的執行過程就如上圖所示,先把一個大任務分解(fork)成許多個獨立的小任務,然后起多線程並行去處理這些小任務。處理完得到結果后再進行合並(join)就得到我們的最終結果。顯而易見的這個框架是借助了現代計算機多核的優勢並行去處理數據。這看起來好像沒有什么特別之處,這個套路很多人都會,並且工作中也會經常運用~。其實fork/join的最特別之處在於它還運用了一種叫work-stealing(工作竊取)的算法,這種算法的設計思路在於把分解出來的小任務放在多個雙端隊列中,而線程在隊列的頭和尾部都可獲取任務。當有線程把當前負責隊列的任務處理完之后,它還可以從那些還沒有處理完的隊列的尾部竊取任務來處理,這連線程的空余時間也充分利用了!。work-stealing原理圖如下:

2、實現fork/join 定義了哪些角色?。
了解設計原理,這僅僅是第一步!要了解別人整個的實現思路, 還需要了解別人為了實現這個框架定義了哪些角色,並了解這些角色的職責范圍是什么的。因為知道誰負責了什么,誰做什么,這樣整個邏輯才能串起來!在JAVA里面角色是以類的形式定義的,而了解類的行為最直接的方式就是看定義的公共方法~。
這里介紹JDK里面與fork/join相關的主要幾個類:
ForkJoinPool:充當fork/join框架里面的管理者,最原始的任務都要交給它才能處理。它負責控制整個fork/join有多少個workerThread,workerThread的創建,激活都是由它來掌控。它還負責workQueue隊列的創建和分配,每當創建一個workerThread,它負責分配相應的workQueue。然后它把接到的活都交給workerThread去處理,它可以說是整個frok/join的容器。
ForkJoinWorkerThread:fork/join里面真正干活的"工人",本質是一個線程。里面有一個ForkJoinPool.WorkQueue的隊列存放着它要干的活,接活之前它要向ForkJoinPool注冊(registerWorker),拿到相應的workQueue。然后就從workQueue里面拿任務出來處理。它是依附於ForkJoinPool而存活,如果ForkJoinPool的銷毀了,它也會跟着結束。
ForkJoinPool.WorkQueue: 雙端隊列就是它,它負責存儲接收的任務。
ForkJoinTask:代表fork/join里面任務類型,我們一般用它的兩個子類RecursiveTask、RecursiveAction。這兩個區別在於RecursiveTask任務是有返回值,RecursiveAction沒有返回值。任務的處理邏輯包括任務的切分都集中在compute()方法里面。
3、fork/join初始化時做了什么
大到一個系統,小到一個框架,初始化工作往往是體現邏輯的一個重要地方!因為這是開始的地方,后面的邏輯會有依賴!所以把初始化看明白了,后面很多邏輯就容易理解多了。
下面上一段代碼,(ps:這段代碼是在網上找到的,並做了一小部分的修改)
public class CountTask extends RecursiveTask<Integer> {
private static final int THRESHOLD = 2; //閥值
private int start;
private int end;
public CountTask(int start,int end){
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
int sum = 0;
boolean canCompute = (end - start) <= THRESHOLD;
if(canCompute){
for(int i = start; i <= end; i++){
sum += i;
}
}else{
int middle = (start + end) / 2;
CountTask leftTask = new CountTask(start,middle);
CountTask rightTask = new CountTask(middle + 1,end);
//執行子任務
leftTask.fork();
rightTask.fork();
//等待子任務執行完,並得到其結果
Integer rightResult = rightTask.join();
Integer leftResult = leftTask.join();
//合並子任務
sum = leftResult + rightResult;
}
return sum;
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
ForkJoinPool forkJoinPool = new ForkJoinPool();
CountTask countTask = new CountTask(1,200);
ForkJoinTask<Integer> forkJoinTask = forkJoinPool.submit(countTask);
System.out.println(forkJoinTask.get());
}
}
代碼的執行過程解釋起來也是很簡單就是把[1,200],分成[1,100],[101,200],然后再對每個部分進行一個遞歸分解最終分解成[1,2],[3,4],[5,6].....[199,200]獨立的小任務,然后兩兩求和合並。
其實顯然易見負責整個fork/join初始化工作的就是ForkJoinPool!初始化代碼就是那一行 ForkJoinPool forkJoinPool = new ForkJoinPool(),點進去查看源碼。
ForkJoinPool forkJoinPool = new ForkJoinPool();
//最終調用到這段代碼
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism), //並行度,當前機器的cpu核數
checkFactory(factory), //工作線程創建工廠
handler, //異常處理handler
asyncMode ? FIFO_QUEUE : LIFO_QUEUE, //任務隊列出隊模式 異步:先進先出,同步:后進先出
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission();
}
看完初始化的代碼我們可以知道原來創建ForkJoinPool創建workerThread的工作都是統一由一個叫ForkJoinWorkerThreadFactory的工廠去創建,創建出來的線程都有一個統一的前輟名稱"ForkJoinPool-" + nextPoolId() + "-worker-".隊列出隊模式是LIFO(后進先出),那這樣后面的入隊的任務是會被先處理的。所以上面提到對代碼做了一些修改就是先處理rightTask,再處理leftTask。這其實是對代碼的一種優化!
//執行子任務
leftTask.fork();
rightTask.fork();
Integer rightResult = rightTask.join();
Integer leftResult = leftTask.join();
4、任務的提交邏輯?
fork/join其實大部分邏輯處理操作都集中在提交任務和處理任務這兩塊,了解任務的提交基本上后面就很容易理解了。
fork/join提交任務主要分為兩種:
第一種:第一次提交到forkJoinPool
ForkJoinTask<Integer> forkJoinTask = forkJoinPool.submit(countTask);
第二種:任務切分之后的提交
leftTask.fork();
rightTask.fork();
提交到forkJoinPool :
代碼調用路徑 submit(ForkJoinTask<T> task) -> externalPush(ForkJoinTask<?> task) -> externalSubmit(ForkJoinTask<?> task)
下面貼上externalSubmit的詳細代碼,着重留意注釋的部分。
private void externalSubmit(ForkJoinTask<?> task) {
int r; // initialize caller's probe
if ((r = ThreadLocalRandom.getProbe()) == 0) {
ThreadLocalRandom.localInit();
r = ThreadLocalRandom.getProbe();
}
for (;;) { //采用循環入隊的方式
WorkQueue[] ws; WorkQueue q; int rs, m, k;
boolean move = false;
if ((rs = runState) < 0) {
tryTerminate(false, false); // help terminate
throw new RejectedExecutionException();
}
else if ((rs & STARTED) == 0 || // initialize 初始化操作
((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
int ns = 0;
rs = lockRunState();
try {
if ((rs & STARTED) == 0) {
U.compareAndSwapObject(this, STEALCOUNTER, null,
new AtomicLong());
// create workQueues array with size a power of two
int p = config & SMASK; // ensure at least 2 slots //config就是cpu的核數
int n = (p > 1) ? p - 1 : 1;
n |= n >>> 1; n |= n >>> 2; n |= n >>> 4;
n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1; //算出workQueues的大小n,n一定是2的次方數
workQueues = new WorkQueue[n]; //初始化隊列,然后跳到最外面的循環繼續把任務入隊~
ns = STARTED;
}
} finally {
unlockRunState(rs, (rs & ~RSLOCK) | ns);
}
}
else if ((q = ws[k = r & m & SQMASK]) != null) { //選中了一個一個非空隊列
if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) { //利用cas操作加鎖成功!
ForkJoinTask<?>[] a = q.array;
int s = q.top;
boolean submitted = false; // initial submission or resizing
try { // locked version of push
if ((a != null && a.length > s + 1 - q.base) ||
(a = q.growArray()) != null) {
int j = (((a.length - 1) & s) << ASHIFT) + ABASE; //計算出任務在隊列中的位置
U.putOrderedObject(a, j, task); //把任務放在隊列中
U.putOrderedInt(q, QTOP, s + 1); //更新一次存放的位置
submitted = true;
}
} finally {
U.compareAndSwapInt(q, QLOCK, 1, 0); //利用cas操作釋放鎖!
}
if (submitted) {
signalWork(ws, q);
return; //任務入隊成功了!跳出循環!
}
}
move = true; // move on failure
}
else if (((rs = runState) & RSLOCK) == 0) { // create new queue 選中的隊列是空,初始化完隊列,然后繼續入隊!
q = new WorkQueue(this, null);
q.hint = r;
q.config = k | SHARED_QUEUE;
q.scanState = INACTIVE;
rs = lockRunState(); // publish index
if (rs > 0 && (ws = workQueues) != null &&
k < ws.length && ws[k] == null)
ws[k] = q; // else terminated
unlockRunState(rs, rs & ~RSLOCK);
}
else
move = true; // move if busy
if (move)
r = ThreadLocalRandom.advanceProbe(r);
}
}
通過對externalSubmit方法的代碼進行分析,我們知道了第一次提交任務給forkJoinPool時是在無限循環for (;;)中入隊。第一步先檢查workQueues是不是還沒有創建,如果沒有,則進行創建。之后跳到外層for循環並隨機選取workQueues里面一個隊列,並判斷隊列是否已創建。沒有創建,則進行創建!后又跳到外層for循環直到選到一個非空隊列並且加鎖成功!這樣最后才把任務入隊~。
所以我們知道fork/join的任務隊列workQueues並不是初始化的時候就創建好了,而是在有任務提交的時候才創建!並且每次入隊時都需要利用cas操作來進行加鎖和釋放鎖!
任務切分之后的提交:
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this); //workerThread直接入自己的workQueue
else
ForkJoinPool.common.externalPush(this);
return this;
}
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) { //隨機選取了一個非空隊列,並且加鎖成功!下面是普通的入隊過程~
ForkJoinTask<?>[] a; int am, n, s;
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
U.putIntVolatile(q, QLOCK, 0);
if (n <= 1)
signalWork(ws, q);
return; //結束方法
}
U.compareAndSwapInt(q, QLOCK, 1, 0); //一定要釋放鎖!
}
//這個就是上面的externalSummit方法,邏輯是一樣的~
externalSubmit(task);
}
從代碼中我們知道了提交一個fork任務的過程和第一次提交到forkJoinPool的過程是大同小異的。主要區分了提交任務的線程是不是workerThread,如果是,任務直接入workerThread當前的workQueue,不是則嘗試選中一個workQueue q。如果q非空並且加鎖成功則進行入隊,否則執行與第一次任務提交到forkJoinPool差不多的邏輯~。
5、任務的消費
提交到任務的最終目的,是為了消費任務並最終獲取到我們想要的結果。介紹任務消費之前我們先了解一個我們的任務ForkJoinTask有哪些關鍵屬性和方法。
/** The run status of this task */
volatile int status; // accessed directly by pool and workers
static final int DONE_MASK = 0xf0000000; // mask out non-completion bits
static final int NORMAL = 0xf0000000; // must be negative
static final int CANCELLED = 0xc0000000; // must be < NORMAL
static final int EXCEPTIONAL = 0x80000000; // must be < CANCELLED
static final int SIGNAL = 0x00010000; // must be >= 1 << 16
static final int SMASK = 0x0000ffff; // short bits for tags
final int doExec() { //任務的執行入口
int s; boolean completed;
if ((s = status) >= 0) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
再看一下RecursiveTask的定義
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
private static final long serialVersionUID = 5232453952276485270L;
/**
* The result of the computation.
*/
V result;
/**
* The main computation performed by this task.
* @return the result of the computation
*/
protected abstract V compute(); //我們實現的處理邏輯
public final V getRawResult() { //獲取返回計算結果
return result;
}
protected final void setRawResult(V value) {
result = value;
}
/**
* Implements execution conventions for RecursiveTask.
*/
protected final boolean exec() {
result = compute(); //存儲計算結果
return true;
}
}
在代碼中我們看到任務的真正執行鏈路是 doExec -> exec -> compute -> 最后設置status 和 result。既然定義狀態status並且還是volatile類型我們可以推斷出workerThread在獲取到執行任務之后都會先判斷status是不是已完成或者異常狀態,才決定要不要處理該任務。
下面看一下任務真正的處理邏輯代碼!
Integer rightResult = rightTask.join()
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
//執行處理前先判斷staus是不是已完成,如果完成了就直接返回
//因為這個任務可能被其它線程竊取過去處理完了
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
代碼的調用鏈是從上到下。整體處理邏輯如下:
線程是workerThread:
先判斷任務是否已經處理完成,任務完成直接返回,沒有則直接嘗試出隊tryUnpush(this) 然后執行任務處理doExec()。如果沒有出隊成功或者處理成功,則執行wt.pool.awaitJoin(w, this, 0L)。wt.pool.awaitJoin(w, this, 0L)的處理邏輯簡單來說也是在一個for(;;)中不斷的輪詢任務的狀態是不是已完成,完成就直接退出方法。否就繼續嘗試出隊處理。直到任務完成或者超時為止。
線程不是workerThread:
直接進行入externalAwaitDone()
private int externalAwaitDone() {
int s = ((this instanceof CountedCompleter) ? // try helping
ForkJoinPool.common.externalHelpComplete(
(CountedCompleter<?>)this, 0) :
ForkJoinPool.common.tryExternalUnpush(this) ? doExec() : 0);
if (s >= 0 && (s = status) >= 0) {
boolean interrupted = false;
do {
if (U.compareAndSwapInt(this, STATUS, s, s | SIGNAL)) {
synchronized (this) {
if (status >= 0) {
try {
wait(0L);
} catch (InterruptedException ie) {
interrupted = true;
}
}
else
notifyAll();
}
}
} while ((s = status) >= 0);
if (interrupted)
Thread.currentThread().interrupt();
}
return s;
externalAwaitDone的處理邏輯其實也比較簡單,當前線程自己先嘗試把任務出隊ForkJoinPool.common.tryExternalUnpush(this) ? doExec()然后處理掉,如果不成功就交給workerThread去處理,然后利用object/wait的經典方法去監聽任務status的狀態變更。
6、任務的竊取
一直說fork/join的任務是work-stealing(工作竊取),那任務究竟是怎么被竊取的呢。我們分析一下任務是由workThread來竊取的,workThread是一個線程。線程的所有邏輯都是由run()方法執行,所以任務的竊取邏輯一定在run()方法中可以找到!
public void run() { //線程run方法
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart();
pool.runWorker(workQueue); //在這里處理任務隊列!
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
/**
* Top-level runloop for workers, called by ForkJoinWorkerThread.run.
*/
final void runWorker(WorkQueue w) {
w.growArray(); // allocate queue 進行隊列的初始化
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask<?> t;;) { //又是無限循環處理任務!
if ((t = scan(w, r)) != null) //在這里獲取任務!
w.runTask(t);
else if (!awaitWork(w, r))
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
其實只要看下面的英文注釋就知道了大概scan(WorkQueue w, int r)就是用來竊取任務的!
/**
* Scans for and tries to steal a top-level task. Scans start at a
* random location, randomly moving on apparent contention,
* otherwise continuing linearly until reaching two consecutive
* empty passes over all queues with the same checksum (summing
* each base index of each queue, that moves on each steal), at
* which point the worker tries to inactivate and then re-scans,
* attempting to re-activate (itself or some other worker) if
* finding a task; otherwise returning null to await work. Scans
* otherwise touch as little memory as possible, to reduce
* disruption on other scanning threads.
*
* @param w the worker (via its WorkQueue)
* @param r a random seed
* @return a task, or null if none found
*/
private ForkJoinTask<?> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
int b, n; long c;
if ((q = ws[k]) != null) { //隨機選中了非空隊列 q
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + ABASE; //從尾部出隊,b是尾部下標
if ((t = ((ForkJoinTask<?>)
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
if (U.compareAndSwapObject(a, i, t, null)) { //利用cas出隊
q.base = b + 1;
if (n < -1) // signal others
signalWork(ws, q);
return t; //出隊成功,成功竊取一個任務!
}
}
else if (oldSum == 0 && // try to activate 隊列沒有激活,嘗試激活
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
}
if (ss < 0) // refresh
ss = w.scanState;
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}
//k = k + 1表示取下一個隊列 如果(k + 1) & m == origin表示 已經遍歷完所有隊列了
if ((k = (k + 1) & m) == origin) { // continue until stable
if ((ss >= 0 || (ss == (ss = w.scanState))) &&
oldSum == (oldSum = checkSum)) {
if (ss < 0 || w.qlock < 0) // already inactive
break;
int ns = ss | INACTIVE; // try to inactivate
long nc = ((SP_MASK & ns) |
(UC_MASK & ((c = ctl) - AC_UNIT)));
w.stackPred = (int)c; // hold prev stack top
U.putInt(w, QSCANSTATE, ns);
if (U.compareAndSwapLong(this, CTL, c, nc))
ss = ns;
else
w.scanState = ss; // back out
}
checkSum = 0;
}
}
}
return null;
}
所以我們知道任務的竊取從workerThread運行的那一刻就已經開始了!先隨機選中一條隊列看能不能竊取到任務,取不到則竊取下一條隊列,直接遍歷完一遍所有的隊列,如果都竊取不到就返回null。
以上就是我閱讀fork/join源碼之后總結出來一些心得,寫了那么多我覺得也只是描述了個大概而已,真正詳細有用的東西還需要仔細去閱讀里面的代碼才行。如果大家有興趣的話,不妨也去嘗試一下吧-。-~
