背景
最近在很多JDK源碼中都看到了Treiber stack這個單詞。
- 比如CompletableFuture中的:
volatile Completion stack; // Top of Treiber stack of dependent actions
- 比如FutureTask中的:
/** Treiber stack of waiting threads */
private volatile WaitNode waiters;
- 比如Phaser中的:
/**
* Wait nodes for Treiber stack representing wait queue
*/
static final class QNode implements ForkJoinPool.ManagedBlocker {
final Phaser phaser;
final int phase;
final boolean interruptible;
final boolean timed;
boolean wasInterrupted;
long nanos;
final long deadline;
volatile Thread thread; // nulled to cancel wait
QNode next;
- 還比如ForkJoinPool中的描述:
* Bits and masks for field ctl, packed with 4 16 bit subfields:
* AC: Number of active running workers minus target parallelism
* TC: Number of total workers minus target parallelism
* SS: version count and status of top waiting thread
* ID: poolIndex of top of Treiber stack of waiters
感覺這種名詞出現的頻率有點高,需要深入了解一下。
名稱由來
Treiber Stack最早出現在R. Kent Treiber於1986年發表的論文《Systems Programming: Coping with Parallelism》中。它是一種無鎖並發棧,其無鎖的特性是基於CAS原子操作實現的。
CompletableFuture源碼實現
CompletableFuture的Treiber stack實現感覺有點復雜,因為有其他邏輯摻雜,代碼不容易閱讀。其實抽象來看,Treiber stack首先是個單向鏈表,鏈表頭部即棧頂元素;在入棧和出棧過程中,需要對棧頂元素進行CAS控制,防止多線程情況下數據錯亂。
// Either the result or boxed AltResult
volatile Object result;
// Top of Treiber stack of dependent actions (i.e. the element currently on top of the stack)
volatile Completion stack;
/**
 * Returns true if successfully pushed c onto stack.
 *
 * Makes a single attempt: a false return means the CAS on the stack top failed,
 * and the caller decides whether to retry.
 */
final boolean tryPushStack(Completion c) {
// Snapshot the current top of the stack.
Completion h = stack;
// Link c to the snapshot BEFORE publishing it (lazySetNext is defined outside this
// excerpt; presumably it stores h into c.next).
lazySetNext(c, h);
// Install c as the new top only if the top is still the snapshot h.
return UNSAFE.compareAndSwapObject(this, STACK, h, c);
}
/** Pushes c onto the stack unconditionally, spinning on tryPushStack until one attempt succeeds. */
final void pushStack(Completion c) {
    while (!tryPushStack(c)) {
        // The attempt failed; simply try again.
    }
}
簡單來看,入棧的步驟如下:
- 嘗試入棧,利用CAS將新的節點作為棧頂元素,新節點next賦值為舊棧頂元素
- 嘗試入棧成功,即結束;入棧失敗,繼續重試上面的操作
FutureTask實現
FutureTask用了Treiber Stack來維護等待任務完成的線程,在FutureTask的任務完成/取消/異常后在finishCompletion鈎子方法中會喚醒棧中等待的線程。
Treiber Stack抽象實現
入棧
void push(Node new) {
do {
} while(!tryPush(new)) // 嘗試入棧
}
boolean tryPush(node) {
Node oldHead = head;
node.next = oldHead; // 新節點next賦值為舊棧頂元素
return CAS(oldHead, node); // 利用CAS將新的節點作為棧頂元素
}
出棧
對於出棧,要做的工作是通過CAS將棧頂更新為舊棧頂的next節點(即新的棧頂),然後斷開被移除的舊棧頂節點的next引用,等待垃圾回收。偽代碼:
// Removes and returns the element on top of the stack, or null if the stack is empty.
E pop() {
    Node<E> oldHead;
    Node<E> newHead;
    do {
        oldHead = head;
        // Empty stack: nothing to pop.
        if (oldHead == null)
            return null;
        newHead = oldHead.next;
    } while (!CAS(oldHead, newHead)); // retry if another thread changed head
    // Drop the removed node's next reference so it can be garbage-collected.
    // The item must stay intact: it is the value handed back to the caller.
    oldHead.next = null;
    return oldHead.item;
}
示例
import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * A lock-free LIFO stack (Treiber stack) implemented with {@code sun.misc.Unsafe}
 * compare-and-swap.
 *
 * <p>The stack is a singly linked list whose first node is the top of the stack.
 * {@link #push} and {@link #pop} each retry a CAS on {@code head} until it succeeds.
 * Every attempt is traced to {@code System.out} for demonstration purposes.
 *
 * @param <E> element type; {@code null} elements are rejected
 * @author Charles
 */
public class TreiberStack<E> {

    /** Top of the stack, or {@code null} when the stack is empty; replaced only via CAS. */
    private volatile Node<E> head;

    /**
     * Pushes {@code item} onto the top of the stack, retrying until the CAS succeeds.
     *
     * @param item the element to push
     * @throws NullPointerException if {@code item} is {@code null}
     */
    public void push(E item) {
        Objects.requireNonNull(item);
        Node<E> newHead = new Node<>(item);
        Node<E> oldHead;
        int count = 0;
        do {
            oldHead = head;
            // Link the new node BEFORE the CAS publishes it. Linking afterwards would
            // let a concurrent pop observe the new head with next == null and drop
            // every element below it.
            newHead.next = oldHead;
            count++;
        } while (!tryPush(oldHead, newHead, count));
    }

    /**
     * Makes one CAS attempt to swing {@code head} from {@code oldHead} to {@code newHead}
     * and traces the outcome.
     *
     * @param count 1-based attempt number, used only in the trace line
     * @return {@code true} if the CAS succeeded
     */
    private boolean tryPush(Node<E> oldHead, Node<E> newHead, int count) {
        boolean isSuccess = UNSAFE.compareAndSwapObject(this, HEAD, oldHead, newHead);
        System.out.println(currentThreadName() + " try push [" + count + "]," +
                " oldHead = " + getValue(oldHead) +
                " newHead = " + getValue(newHead) +
                " isSuccess = " + isSuccess);
        return isSuccess;
    }

    /**
     * Removes and returns the element on top of the stack.
     *
     * @return the popped element, or {@code null} if the stack is empty
     */
    public E pop() {
        Node<E> oldHead;
        Node<E> newHead;
        do {
            oldHead = head;
            // Derive the candidate from the snapshot (not from a second read of head),
            // and tolerate both an empty stack and a single-element stack.
            newHead = (oldHead == null) ? null : oldHead.next;
            System.out.println(currentThreadName() + " do pop:" +
                    " oldHead = " + getValue(oldHead) +
                    " newHead = " + getValue(newHead));
            if (oldHead == null) {
                return null;
            }
        } while (!tryPop(oldHead, newHead));
        // Unlink the removed node so it does not keep the rest of the list reachable.
        oldHead.next = null;
        return oldHead.item;
    }

    /**
     * Makes one CAS attempt to swing {@code head} from {@code oldHead} to {@code newHead}
     * and traces the outcome.
     *
     * @return {@code true} if the CAS succeeded
     */
    private boolean tryPop(Node<E> oldHead, Node<E> newHead) {
        boolean isSuccess = UNSAFE.compareAndSwapObject(this, HEAD, oldHead, newHead);
        System.out.println(currentThreadName() + " try pop:" +
                " oldHead = " + getValue(oldHead) +
                " currentHead = " + getValue(head) +
                " newHead = " + getValue(newHead) +
                " isSuccess: " + isSuccess);
        return isSuccess;
    }

    /** Returns the item held by {@code n}, or {@code null} when {@code n} is {@code null}. */
    private E getValue(Node<E> n) {
        return Optional.ofNullable(n).map(t -> t.item).orElse(null);
    }

    /** Singly linked list node; a fresh node is allocated for every push. */
    private static class Node<E> {
        E item;
        Node<E> next;

        Node(E item) {
            this.item = item;
        }
    }

    // Unsafe mechanics
    private static final sun.misc.Unsafe UNSAFE;
    /** Field offset of {@code head}, as required by {@code compareAndSwapObject}. */
    private static final long HEAD;

    static {
        try {
            // Unsafe has no public accessor for application code; read its singleton reflectively.
            Field getUnsafe = sun.misc.Unsafe.class.getDeclaredField("theUnsafe");
            getUnsafe.setAccessible(true);
            UNSAFE = (Unsafe) getUnsafe.get(null);
            Class<?> k = TreiberStack.class;
            HEAD = UNSAFE.objectFieldOffset(k.getDeclaredField("head"));
        } catch (Exception x) {
            throw new Error(x);
        }
    }

    /** Demo payload wrapping a random non-negative integer. */
    private static class RandomValue {
        private final Integer value;

        public RandomValue() {
            this.value = new Random().nextInt(Integer.MAX_VALUE);
        }

        public Integer getValue() {
            return value;
        }

        @Override
        public String toString() {
            return value.toString();
        }
    }

    /** Prefix for trace lines: current nano time plus the calling thread's name. */
    private static String currentThreadName() {
        return System.nanoTime() + " / " + Thread.currentThread().getName();
    }

    /**
     * Demo: submits 5 pushes followed by 50 pops to a 10-thread pool and lets the
     * trace output show how the CAS attempts interleave.
     */
    public static void main(String[] args) throws InterruptedException {
        TreiberStack<RandomValue> ts = new TreiberStack<>();
        ExecutorService es = Executors.newFixedThreadPool(10);
        try {
            Thread.sleep(2000);
            for (int i = 0; i < 5; i++) {
                es.submit(() -> ts.push(new RandomValue()));
            }
            for (int i = 0; i < 50; i++) {
                es.submit((Runnable) ts::pop);
            }
        } finally {
            // Already-submitted tasks still run; without this the pool's non-daemon
            // threads keep the JVM alive forever.
            es.shutdown();
        }
    }
}