【JAVA】ThreadLocal源碼分析


ThreadLocal內部是用一張哈希表來存儲:

 1 static class ThreadLocalMap {
 2     static class Entry extends WeakReference<ThreadLocal<?>> {
 3             /** The value associated with this ThreadLocal. */
 4             Object value;
 5 
 6             Entry(ThreadLocal<?> k, Object v) {
 7                 super(k);
 8                 value = v;
 9             }
10     }
11     private static final int INITIAL_CAPACITY = 16;
12     private Entry[] table;
13     private int size = 0;
14     private int threshold;
15     ......

看過HashMap的話就很容易理解上述內容【Java】HashMap源碼分析

而在Thread類中有一個ThreadLocalMap 的成員:

1 ThreadLocal.ThreadLocalMap threadLocals = null;

所以不難得出如下關系:

每一個線程都有一張線程私有的Map,存放多個線程本地變量

set()方法:

 1 public void set(T value) {
 2         Thread t = Thread.currentThread();
 3         ThreadLocalMap map = getMap(t);
 4         if (map != null)
 5             map.set(this, value);
 6         else
 7             createMap(t, value);
 8 }
 9 
10 ThreadLocalMap getMap(Thread t) {
11         return t.threadLocals;
12 }

不難看出,先獲取當前線程的Thread對象,再得到該Thread對象的ThreadLocalMap 成員map,若map為空,需要先createMap()方法,若不為空,則需要調用map的set()方法

 1 void createMap(Thread t, T firstValue) {
 2         t.threadLocals = new ThreadLocalMap(this, firstValue);
 3 }
 4 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
 5             table = new Entry[INITIAL_CAPACITY];
 6             int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
 7             table[i] = new Entry(firstKey, firstValue);
 8             size = 1;
 9             setThreshold(INITIAL_CAPACITY);
10 }
11 private void setThreshold(int len) {
12             threshold = len * 2 / 3;
13 }

createMap方法會創建一個ThreadLocalMap對象,在ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue)構造方法中,可以看出和HashMap很相似,通過firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)取模,計算出哈希表的下標,將創建好的Entry對象放入該位置,再根據表長計算閾值,可以看出負載因子是2/3,初始哈希表的大小是16。

 1 private void set(ThreadLocal<?> key, Object value) {
 2     Entry[] tab = table;
 3     int len = tab.length;
 4     int i = key.threadLocalHashCode & (len-1);
 5     
 6     for (Entry e = tab[i];
 7          e != null;
 8          e = tab[i = nextIndex(i, len)]) {
 9         ThreadLocal<?> k = e.get();
10     
11         if (k == key) {
12             e.value = value;
13             return;
14         }
15     
16         if (k == null) {
17             replaceStaleEntry(key, value, i);
18             return;
19         }
20     }
21     
22     tab[i] = new Entry(key, value);
23     int sz = ++size;
24     if (!cleanSomeSlots(i, sz) && sz >= threshold)
25         rehash();
26 }

不難看出,通過key.threadLocalHashCode & (len-1)計算出哈希表的下標,判斷該位置的Entry是否為null,若為null,則創建Entry對象,將其放入該下標位置;若Entry已存在,則需要解決哈希沖突,重新計算下標。最后size自增,再根據!cleanSomeSlots(i, sz) && sz >= threshold進行判斷是否需要進行哈希表的調整。

在解決哈希沖突的上,常用的有開鏈法、線性探測法和再散列法,HashMap中使用的是開鏈法,而ThreadLocal使用的是線性探測法,即發生哈希沖突,往后移動到合適位置。

1 private static int nextIndex(int i, int len) {
2             return ((i + 1 < len) ? i + 1 : 0);
3 }
4 private static int prevIndex(int i, int len) {
5             return ((i - 1 >= 0) ? i - 1 : len - 1);
6 }

從這兩個操作看出,ThreadLocal中的哈希表是利用了循環數組的方式,進行環形的線性探測
在上述for循環中,會取出該Entry上的ThreadLocal對象(鍵)進行判斷,若相同則直接覆蓋,若為null,說明該Entry空間存在但其ThreadLocal對象的指向為null,需要進行調整;若都不成立,則繼續循環,重復以上操作。

Entry空間指向存在但ThreadLocal對象的指向為null是因為Entry繼承自WeakReference<ThreadLocal<?>>,是弱引用,存在被GC的情況,所以會存在這種情況,視為臟Entry,接下來的操作就是通過replaceStaleEntry進行處理。

 1 private void replaceStaleEntry(ThreadLocal<?> key, Object value,
 2                                        int staleSlot) {
 3     Entry[] tab = table;
 4     int len = tab.length;
 5     Entry e;
 6     
 7     int slotToExpunge = staleSlot;
 8     for (int i = prevIndex(staleSlot, len);
 9          (e = tab[i]) != null;
10          i = prevIndex(i, len))
11         if (e.get() == null)
12             slotToExpunge = i;
13             
14     for (int i = nextIndex(staleSlot, len);
15          (e = tab[i]) != null;
16          i = nextIndex(i, len)) {
17         ThreadLocal<?> k = e.get();
18     
19         if (k == key) {
20             e.value = value;
21     
22             tab[i] = tab[staleSlot];
23             tab[staleSlot] = e;
24     
25             if (slotToExpunge == staleSlot)
26                 slotToExpunge = i;
27             cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
28             return;
29         }
30     
31         if (k == null && slotToExpunge == staleSlot)
32             slotToExpunge = i;
33     }
34     
35     tab[staleSlot].value = null;
36     tab[staleSlot] = new Entry(key, value);
37     
38     if (slotToExpunge != staleSlot)
39         cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
40 }

可以清楚看到第一個for循環前向遍歷查找臟Entry,用slotToExpunge保存臟Entry下標;
第二個for循環后向遍歷,若遇到ThreadLocal向同,更新value,然后與下標為staleSlot(傳入進來的臟Entry)進行交換,接着判斷前向查找臟Entry是否存在,slotToExpunge == staleSlot說明的就是前向查找沒找到,就更改slotToExpunge的值,然后進行清理操作,結束掉;若后向遍歷遇到臟Entry,並且前向沒找到,更改slotToExpunge的值,為清理時用,繼續循環。
若不存在和ThreadLocal引用相同的Entry,則需要將staleSlot的位置的Entry替換為一個新的Entry對象,tab[staleSlot].value = null是為了GC;
最后根據slotToExpunge來判斷前向后向遍歷中是否存在臟Entry,若存在還需要進行清理。

其中的expungeStaleEntry方法如下:

 1 private int expungeStaleEntry(int staleSlot) {
 2     Entry[] tab = table;
 3     int len = tab.length;
 4     
 5     // expunge entry at staleSlot
 6     tab[staleSlot].value = null;
 7     tab[staleSlot] = null;
 8     size--;
 9     
10     // Rehash until we encounter null
11     Entry e;
12     int i;
13     for (i = nextIndex(staleSlot, len);
14          (e = tab[i]) != null;
15          i = nextIndex(i, len)) {
16         ThreadLocal<?> k = e.get();
17         if (k == null) {
18             e.value = null;
19             tab[i] = null;
20             size--;
21         } else {
22             int h = k.threadLocalHashCode & (len - 1);
23             if (h != i) {
24                 tab[i] = null;
25     
26                 // Unlike Knuth 6.4 Algorithm R, we must scan until
27                 // null because multiple entries could have been stale.
28                 while (tab[h] != null)
29                     h = nextIndex(h, len);
30                 tab[h] = e;
31             }
32         }
33     }
34     return i;
35 }

可以看到,先把當前位置的臟Entry清除掉(置為null),size自減。然后從當前位置后向遍歷,若遇到臟Entry直接清除,size自減;若不是臟Entry,則需要判斷它是否經過哈希沖突的調整的,若調整過,需要將其重新調整,最后返回當前位置為null的table下標;綜上,該方法就是后向清除臟Entry,再把調整需要調整的Entry。

在replaceStaleEntry方法中,調用expungeStaleEntry清除掉臟Entry后,還要用cleanSomeSlots方法清除掉返回回來的下標后的臟Entry;

cleanSomeSlots方法:

 1 private boolean cleanSomeSlots(int i, int n) {
 2     boolean removed = false;
 3     Entry[] tab = table;
 4     int len = tab.length;
 5     do {
 6         i = nextIndex(i, len);
 7         Entry e = tab[i];
 8         if (e != null && e.get() == null) {
 9             n = len;
10             removed = true;
11             i = expungeStaleEntry(i);
12         }
13     } while ( (n >>>= 1) != 0);
14     return removed;
15 }

從下標為i后面的開始后向遍歷,遇到臟Entry調用expungeStaleEntry清除掉,令removed為true,i會變為下標為null的位置,繼續循環;其中n的用途是控制循環次數,當遇到臟Entry時,會令n等於表長,擴大搜索范圍。

在set方法中,最后根據!cleanSomeSlots(i, sz) && sz >= threshold,判斷是否清理掉了臟Entry,若清理了什么都不做;若沒有清理,還會判斷是否達到閾值,進而是否需要rehash操作;

rehash方法:

1 private void rehash() {
2     expungeStaleEntries();
3     
4     // Use lower threshold for doubling to avoid hysteresis
5     if (size >= threshold - threshold / 4)
6         resize();
7 }

首先調用expungeStaleEntries方法:

1 private void expungeStaleEntries() {
2     Entry[] tab = table;
3     int len = tab.length;
4     for (int j = 0; j < len; j++) {
5         Entry e = tab[j];
6         if (e != null && e.get() == null)
7             expungeStaleEntry(j);
8     }
9 }

可以看到expungeStaleEntries方法是遍歷整個哈希表,通過調用expungeStaleEntry方法清除掉所有臟Entry。
由於清除掉了臟Entry,還需要對size進行判斷,看是否達到了閾值的3/4(提前觸發resize),來判斷是否真的需要resize;

resize方法:

 1 private void resize() {
 2     Entry[] oldTab = table;
 3     int oldLen = oldTab.length;
 4     int newLen = oldLen * 2;
 5     Entry[] newTab = new Entry[newLen];
 6     int count = 0;
 7     
 8     for (int j = 0; j < oldLen; ++j) {
 9         Entry e = oldTab[j];
10         if (e != null) {
11             ThreadLocal<?> k = e.get();
12             if (k == null) {
13                 e.value = null; // Help the GC
14             } else {
15                 int h = k.threadLocalHashCode & (newLen - 1);
16                 while (newTab[h] != null)
17                     h = nextIndex(h, newLen);
18                 newTab[h] = e;
19                 count++;
20             }
21         }
22     }
23     
24     setThreshold(newLen);
25     size = count;
26     table = newTab;
27 }

剛開始的操作可以清楚的明白,每次擴容的大小都是原來的兩倍;然后遍歷原表的所有Entry,遇到臟Entry直接賦值null引起幫助GC;遇到有效Entry則需要根據新的表長重新計算下標,再通過線性探測完成新表的填充;填充完畢,計算新的閾值,給size和table賦值,結束操作。

至此,有關set的操作就結束了,還剩下get和remove:

get方法:

 1 public T get() {
 2     Thread t = Thread.currentThread();
 3     ThreadLocalMap map = getMap(t);
 4     if (map != null) {
 5         ThreadLocalMap.Entry e = map.getEntry(this);
 6         if (e != null) {
 7             @SuppressWarnings("unchecked")
 8             T result = (T)e.value;
 9             return result;
10         }
11     }
12     return setInitialValue();
13 }

和set一樣,先獲取當前線程,再根據當前線程獲取其ThreadLocalMap成員map;
若map不為null,通過map的getEntry方法得到Entry對象,若Entry不為null則直接返回Entry的value;
若map為null,或者map不為null,但是Entry是null,則都需要調用setInitialValue方法。

getEntry方法:

1 private Entry getEntry(ThreadLocal<?> key) {
2     int i = key.threadLocalHashCode & (table.length - 1);
3     Entry e = table[i];
4     if (e != null && e.get() == key)
5         return e;
6     else
7         return getEntryAfterMiss(key, i, e);
8 }

根據ThreadLocal定位哈希表的下標,若滿足則直接返回,若不是,調用getEntryAfterMiss繼續找。

getEntryAfterMiss方法:

 1 private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
 2     Entry[] tab = table;
 3     int len = tab.length;
 4     
 5     while (e != null) {
 6         ThreadLocal<?> k = e.get();
 7         if (k == key)
 8             return e;
 9         if (k == null)
10             expungeStaleEntry(i);
11         else
12             i = nextIndex(i, len);
13         e = tab[i];
14     }
15     return null;
16 }

看以看到這還是一個后向遍歷的查找,若是找到則直接返回;若遇到臟Entry需要調用expungeStaleEntry方法清理掉;最后還沒找到返回null。

setInitialValue方法:

 1 private T setInitialValue() {
 2     T value = initialValue();
 3     Thread t = Thread.currentThread();
 4     ThreadLocalMap map = getMap(t);
 5     if (map != null)
 6        map.set(this, value);
 7     else
 8        createMap(t, value);
 9     return value;
10 }

先調用initialValue方法,該方法需要使用者進行覆蓋,否則返回的是null。所以當沒有使用set方法時覆蓋initialValue方法時還是會調用set方法的,效果是一樣的。

1 protected T initialValue() {
2         return null;
3 }

后面的操作就和set方法一樣。get方法至此結束。

remove方法:

1 public void remove() {
2     ThreadLocalMap m = getMap(Thread.currentThread());
3     if (m != null)
4         m.remove(this);
5 }

以當前線程為參數調用getMap方法:

1 ThreadLocalMap getMap(Thread t) {
2     return t.threadLocals;
3 }

若是當前線程的ThreadLocalMap對象不存在,什么都不做,若存在,調用內部的remove方法:

 1 private void remove(ThreadLocal<?> key) {
 2     Entry[] tab = table;
 3     int len = tab.length;
 4     int i = key.threadLocalHashCode & (len-1);
 5     for (Entry e = tab[i];
 6          e != null;
 7          e = tab[i = nextIndex(i, len)]) {
 8         if (e.get() == key) {
 9             e.clear();
10             expungeStaleEntry(i);
11             return;
12         }
13     }
14 }

首先根據ThreadLocal找到其對應的的哈希表的下標(不一定是它的下標,會有哈希沖突的可能性),然后開始后向遍歷,找到真正的位置,調用clear方法刪除掉,順便還進行臟Entry的清理。

clear方法是Reference類的方法:

1 public void clear() {
2     this.referent = null;
3 }

可以看到僅僅只是令指向變為null,因為Reference是WeakReference的父類,ThreadLocalMap繼承自WeakReference<ThreadLocal<?>>,弱引用變為null,就會變成臟Entry,所以就需要expungeStaleEntry對其清理。為什么不令tab[i]直接為null,就是因為在expungeStaleEntry執行時還會清理遇到的臟Entry,這樣可以盡可能多的刪除掉臟Entry。

ThreadLocal源碼分析到此結束。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM