先上結論
原理
- join 原理:在當前線程中調用另一個線程線程 thread 的 join() 方法時,會調用該 thread 的 wait() 方法,直到這個 thread 執行完畢(JVM在 run() 方法執行完后調用 exit() 方法,而 exit() 方法里調用了 notifyAll() 方法)會調用 notifyAll() 方法主動喚醒當前線程。
源碼如下:
public final void join() throws InterruptedException {
join(0);
}
/**
* 注意這個方法是同步的
*/
public final synchronized void join(long millis)
throws InterruptedException {
long base = System.currentTimeMillis();
long now = 0;
if (millis < 0) {
throw new IllegalArgumentException("timeout value is negative");
}
/**
* join方法默認參數為0,會直接阻塞當前線程
*/
if (millis == 0) {
while (isAlive()) {
wait(0);
}
} else {
while (isAlive()) {
long delay = millis - now;
if (delay <= 0) {
break;
}
wait(delay);
now = System.currentTimeMillis() - base;
}
}
}
public final native boolean isAlive();
}
- countDownLatch 原理:可以理解為一個計數器。在初始化 CountDownLatch 的時候會在類的內部初始化一個int的變量,每當調用 countDownt() 方法的時候這個變量的值減1,而 await() 方法就是去判斷這個變量的值是否為0,是則表示所有的操作都已經完成,否則繼續等待。
源碼如下(源碼比較少,直接全貼出來了,所有中文注釋是我自己加上去的):
public static class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
/**
* 初始化state
*/
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
/**
* 嘗試獲取同步狀態
* 只有當同步狀態為0的時候返回1
*/
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
/**
* 自旋+CAS的方式釋放同步狀態
*/
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
/**
* 初始化一個同步器
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
/**
* 調用同步器的acquireSharedInterruptibly方法,並且是響應中斷的
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
/**
* 調用同步器的releaseShared方法去讓state減1
*/
public void countDown() {
sync.releaseShared(1);
}
/**
* 獲取剩余的count
*/
public long getCount() {
return sync.getCount();
}
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
區別及注意事項
- join和countDownLatch都能實現讓當前線程阻塞等待其他線程執行完畢,join使用起來更簡便,不過countDownLatch粒度更細。
- 由於CountDownLatch需要開發人員很明確需要等待的條件,否則容易造成await()方法一直阻塞。
如何使用
- 一個簡單的小例子
public class Test {
private static final Logger logger = LoggerFactory.getLogger(Test.class);
public static void main(String[] args) {
long sleepTime = 5000;
try {
TestJoinThread joinThread1 = new TestJoinThread("joinThread1",sleepTime);
TestJoinThread joinThrad2 = new TestJoinThread("joinThrad2",sleepTime);
joinThread1.start();
joinThrad2.start();
joinThread1.join();
joinThrad2.join();
logger.info("主線程開始運行...");
} catch (InterruptedException e) {
logger.error("test join err!",e);
}
try {
CountDownLatch count = new CountDownLatch(2);
TestCountDownLatchThread countDownLatchThread1 = new TestCountDownLatchThread(count,"countDownLatchThread1",sleepTime);
TestCountDownLatchThread countDownLatchThread2 = new TestCountDownLatchThread(count,"countDownLatchThread2",sleepTime);
countDownLatchThread1.start();
countDownLatchThread2.start();
count.await();
logger.info("主線程開始運行...");
} catch (InterruptedException e) {
logger.error("test countDownLatch err!",e);
}
}
static class TestJoinThread extends Thread{
private String threadName;
private long sleepTime;
public TestJoinThread(String threadName,long sleepTime){
this.threadName = threadName;
this.sleepTime = sleepTime;
}
@Override
public void run() {
try{
logger.info(String.format("線程[%s]開始運行...",threadName));
Thread.sleep(sleepTime);
logger.info(String.format("線程[%s]運行結束 耗時[%s]s",threadName,sleepTime/1000));
}catch (Exception e){
logger.error("TestJoinThread run err!",e);
}
}
}
static class TestCountDownLatchThread extends Thread{
private String threadName;
private long sleepTime;
private CountDownLatch countDownLatch;
public TestCountDownLatchThread(CountDownLatch countDownLatch,String threadName,long sleepTime){
this.countDownLatch = countDownLatch;
this.threadName = threadName;
this.sleepTime = sleepTime;
}
@Override
public void run() {
try{
logger.info(String.format("線程[%s]開始運行...",threadName));
Thread.sleep(sleepTime);
logger.info(String.format("線程[%s]運行結束 耗時[%s]s",threadName,sleepTime/1000));
countDownLatch.countDown();
}catch (Exception e){
logger.error("TestCountDownLatchThread run err!",e);
}
}
}
}
日志輸出:
11:18:01.985 [Thread-1] INFO com.sync.Test - 線程[joinThrad2]開始運行...
11:18:01.985 [Thread-0] INFO com.sync.Test - 線程[joinThread1]開始運行...
11:18:06.993 [Thread-1] INFO com.sync.Test - 線程[joinThrad2]運行結束...耗時[5]s
11:18:06.993 [Thread-0] INFO com.sync.Test - 線程[joinThread1]運行結束...耗時[5]s
11:18:06.993 [main] INFO com.sync.Test - 主線程開始運行...
11:18:06.995 [Thread-2] INFO com.sync.Test - 線程[countDownLatchThread1]開始運行...
11:18:06.995 [Thread-3] INFO com.sync.Test - 線程[countDownLatchThread2]開始運行...
11:18:11.996 [Thread-2] INFO com.sync.Test - 線程[countDownLatchThread1]運行結束...耗時[5]s
11:18:11.996 [Thread-3] INFO com.sync.Test - 線程[countDownLatchThread2]運行結束...耗時[5]s
11:18:11.996 [main] INFO com.sync.Test - 主線程開始運行...
可以看到:joinThread1 和 joinThread2 同時開始執行,5s后主線程開始執行。countDownLatchThread1 和 countDownLatchThread2 也是一樣的效果。
那么我上面所說的粒度更細有怎樣的應用場景呢?
我對 TestCountDownLatchThread類 的 run() 方法做一點小改動:
@Override
public void run() {
try{
logger.info(String.format("線程[%s]第一階段開始運行...",threadName);
Thread.sleep(sleepTime);
logger.info(String.format("線程[%s]第一階段運行結束耗時[%s]s",threadName,sleepTime/1000));
countDownLatch.countDown();
logger.info(String.format("線程[%s]第二階段開始運行...",threadName);
Thread.sleep(sleepTime);
logger.info(String.format("線程[%s]第二階段運行結束耗時[%s]s",threadName,sleepTime/1000));
}catch (Exception e){
logger.error("TestCountDownLatchThread run err!",e);
}
}
這個時候日志輸出會變成這樣:
12:59:35.912 [Thread-1] INFO com.sync.Test - 線程[countDownLatchThread2]第一階段開始運行...
12:59:35.912 [Thread-0] INFO com.sync.Test - 線程[countDownLatchThread1]第一階段開始運行...
12:59:40.916 [Thread-0] INFO com.sync.Test - 線程[countDownLatchThread1]第一階段運行結束 耗時[5]s
12:59:40.916 [Thread-1] INFO com.sync.Test - 線程[countDownLatchThread2]第一階段運行結束 耗時[5]s
12:59:40.916 [main] INFO com.sync.Test - 主線程開始運行...
12:59:40.916 [Thread-0] INFO com.sync.Test - 線程[countDownLatchThread1]第二階段開始運行...
12:59:40.916 [Thread-1] INFO com.sync.Test - 線程[countDownLatchThread2]第二階段開始運行...
12:59:45.917 [Thread-0] INFO com.sync.Test - 線程[countDownLatchThread1]第二階段運行結束 耗時[5]s
12:59:45.917 [Thread-1] INFO com.sync.Test - 線程[countDownLatchThread2]第二階段運行結束 耗時[5]s
也就是說如果當前線程只需要等待其他線程一部分任務執行完畢的情況下就可以用 countDownLatch 來實現了,而 join 則實現不了這種粒度的控制。