CountDownLatch概述
日常開發中,經常會遇到類似場景:主線程開啟多個子線程執行任務,需要等待所有子線程執行完畢后再進行匯總。
在同步組件CountDownLatch出現之前,我們可以使用join方法來完成,簡單實現如下:
public class JoinTest {
public static void main(String[] args) throws InterruptedException {
Thread A = new Thread(() -> {
try {
Thread.sleep(1000);
System.out.println("A finish!");
} catch (InterruptedException e) {
e.printStackTrace();
}
});
Thread B = new Thread(() -> {
try {
Thread.sleep(1000);
System.out.println("B finish!");
} catch (InterruptedException e) {
e.printStackTrace();
}
});
System.out.println("main thread wait ..");
A.start();
B.start();
A.join(); // 等待A執行結束
B.join(); // 等待B執行結束
System.out.println("all thread finish !");
}
}
但使用join方法並不是很靈活,並不能很好地滿足某些場景的需要,而CountDownLatch則能夠很好地代替它,並且相比之下,提供了更多靈活的特性:
CountDownLatch相比join方法對線程同步有更靈活的控制,原因如下:
- 調用子線程的join()方法后,該線程會一直被阻塞直到子線程運行完畢,而CountDownLatch使用計數器來允許子線程運行完畢或者運行中遞減計數,await方法返回不一定必須等待線程結束。
- 使用線程池管理線程時,添加Runnable到線程池,沒有辦法再調用線程的join方法了。
使用案例與基本思路
public class TestCountDownLatch {
public static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main (String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
executorService.submit(() -> {
try {
Thread.sleep(1000);
System.out.println("A finish!");
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
countDownLatch.countDown();
}
});
executorService.submit(() -> {
try {
Thread.sleep(1000);
System.out.println("B finish!");
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
countDownLatch.countDown();
}
});
System.out.println("main thread wait ..");
countDownLatch.await();
System.out.println("all thread finish !");
executorService.shutdown();
}
}
// 結果
main thread wait ..
B finish!
A finish!
all thread finish !
- 構建CountDownLatch實例,構造參數傳參為2,內部計數初始值為2。
- 主線程構建線程池,提交兩個任務,接着調用
countDownLatch.await()
陷入阻塞。 - 子線程執行完畢之后調用
countDownLatch.countDown()
,內部計數器減1。 - 所有子線程執行完畢之后,計數為0,此時主線程的await方法返回。
類圖與基本結構
public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
//...
}
private final Sync sync;
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void countDown() {
sync.releaseShared(1);
}
public long getCount() {
return sync.getCount();
}
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
CountDownLatch基於AQS實現,內部維護一個Sync變量,繼承了AQS。
在AQS中,最重要的就是state狀態的表示,在CountDownLatch中使用state表示計數器的值,在初始化的時候,為state賦值。
幾個同步方法實現比較簡單,如果你不熟悉AQS,推薦你瞅一眼前置文章:
- Java並發包源碼學習系列:AbstractQueuedSynchronizer
- Java並發包源碼學習系列:CLH同步隊列及同步資源獲取與釋放
- Java並發包源碼學習系列:AQS共享式與獨占式獲取與釋放資源的區別
- Java並發包源碼學習系列:詳解Condition條件隊列、signal和await
- Java並發包源碼學習系列:掛起與喚醒線程LockSupport工具類
接下來我們簡單看一看實現,主要學習兩個方法:await()和countdown()。
void await()
當線程調用CountDownLatch的await方法后,線程會被阻塞,除非發生下面兩種情況:
- 內部計數器值為0,
getState() == 0
。 - 被其他線程中斷,拋出異常,也就是
currThread.interrupt()
。
// CountDownLatch.java
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// AQS.java
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果線程中斷, 則拋出異常
if (Thread.interrupted())
throw new InterruptedException();
// 由子類實現,這里再Sync中實現,計數器為0就可以返回,否則進入AQS隊列等待
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
// Sync
// 計數器為0 返回1, 否則返回-1
private static final class Sync extends AbstractQueuedSynchronizer {
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
}
boolean await(long timeout, TimeUnit unit)
當線程調用CountDownLatch的await方法后,線程會被阻塞,除非發生下面三種情況:
- 內部計數器值為0,
getState() == 0
,返回true。 - 被其他線程中斷,拋出異常,也就是
currThread.interrupt()
。 - 設置的timeout時間到了,超時返回false。
// CountDownLatch.java
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// AQS.java
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
void countDown()
調用該方法,內部計數值減1,遞減后如果計數器值為0,喚醒所有因調用await方法而被阻塞的線程,否則跳過。
// CountDownLatch.java
public void countDown() {
sync.releaseShared(1);
}
// AQS.java
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
// Sync
private static final class Sync extends AbstractQueuedSynchronizer {
protected boolean tryReleaseShared(int releases) {
// 循環進行CAS操作
for (;;) {
int c = getState();
// 一旦為0,就返回false
if (c == 0)
return false;
int nextc = c-1;
// CAS嘗試將state-1,只有這一步CAS成功且將state變成0的線程才會返回true
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
總結
-
CountDownLatch相比於join方法更加靈活且方便地實現線程間同步,體現在以下幾點:
- 調用子線程的join()方法后,該線程會一直被阻塞直到子線程運行完畢,而CountDownLatch使用計數器來允許子線程運行完畢或者運行中遞減計數,await方法返回不一定必須等待線程結束。
- 使用線程池管理線程時,添加Runnable到線程池,沒有辦法再調用線程的join方法了。
-
CountDownLatch使用state表示內部計數器的值,初始化傳入count。
-
線程調用countdown方法將會原子性地遞減AQS的state值,線程調用await方法后將會置入AQS阻塞隊列中,直到計數器為0,或被打斷,或超時等才會返回,計數器為0時,當前線程還需要喚醒由於await()被阻塞的線程。
參考閱讀
- 《Java並發編程之美》