【Java 並發】詳解 ThreadLocal


前言

ThreadLocal 主要用來提供線程局部變量,也就是變量只對當前線程可見,本文主要記錄一下對於 ThreadLocal 的理解。更多關於 Java 多線程的文章可以轉到 這里

線程局部變量

在多線程環境下,之所以會有並發問題,就是因為不同的線程會同時訪問同一個共享變量,例如下面的形式

public class MultiThreadDemo {

    public static class Number {
        private int value = 0;

        public void increase() throws InterruptedException {
            value = 10;
            Thread.sleep(10);
            System.out.println("increase value: " + value);
        }

        public void decrease() throws InterruptedException {
            value = -10;
            Thread.sleep(10);
            System.out.println("decrease value: " + value);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        final Number number = new Number();
        Thread increaseThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    number.increase();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        });

        Thread decreaseThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    number.decrease();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        });

        increaseThread.start();
        decreaseThread.start();
    }
}

在上面的代碼中,increase 線程和 decrease 線程會操作同一個 number 中 value,那么輸出的結果是不可預測的,因為當前線程修改變量之后但是還沒輸出的時候,變量有可能被另外一個線程修改,下面是一種可能的情況:

increase value: 10
decrease value: 10

一種解決方法是在 increase()decrease() 方法上加上 synchronized 關鍵字進行同步,這種做法其實是將 value 的 賦值打印 包裝成了一個原子操作,也就是說兩者要么同時進行,要不都不進行,中間不會有額外的操作。我們換個角度考慮問題,如果 value 只屬於 increase 線程或者 decrease 線程,而不是被兩個線程共享,那么也不會出現競爭問題。一種比較常見的形式就是局部(local)變量(這里排除局部變量引用指向共享對象的情況),如下所示:

public void increase() throws InterruptedException {
    int value = 10;
    Thread.sleep(10);
    System.out.println("increase value: " + value);
}

不論 value 值如何改變,都不會影響到其他線程,因為在每次調用 increase 方法時,都會創建一個 value 變量,該變量只對當前調用 increase 方法的線程可見。借助於這種思想,我們可以對每個線程創建一個共享變量的副本,該副本只對當前線程可見(可以認為是線程私有的變量),那么修改該副本變量時就不會影響到其他的線程。一個簡單的思路是使用 Map 存儲每個變量的副本,將當前線程的 id 作為 key,副本變量作為 value 值,下面是一個實現:

public class SimpleImpl {

    public static class CustomThreadLocal {
        private Map<Long, Integer> cacheMap = new HashMap<>();

        private int defaultValue ;

        public CustomThreadLocal(int value) {
            defaultValue = value;
        }

        public Integer get() {
            long id = Thread.currentThread().getId();
            if (cacheMap.containsKey(id)) {
                return cacheMap.get(id);
            }
            return defaultValue;
        }

        public void set(int value) {
            long id = Thread.currentThread().getId();
            cacheMap.put(id, value);
        }
    }

    public static class Number {
        private CustomThreadLocal value = new CustomThreadLocal(0);

        public void increase() throws InterruptedException {
            value.set(10);
            Thread.sleep(10);
            System.out.println("increase value: " + value.get());
        }

        public void decrease() throws InterruptedException {
            value.set(-10);
            Thread.sleep(10);
            System.out.println("decrease value: " + value.get());
        }
    }

    public static void main(String[] args) throws InterruptedException {
        final Number number = new Number();
        Thread increaseThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    number.increase();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        });

        Thread decreaseThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    number.decrease();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        });

        increaseThread.start();
        decreaseThread.start();
    }
}

但是上面的實現會存在下面的問題:

  • 每個線程對應的副本變量的生命周期不是由線程決定的,而是由共享變量的生命周期決定的。在上面的例子中,即便線程執行完,只要 number 變量存在,線程的副本變量依然會存在(存放在 number 的 cacheMap 中)。但是作為特定線程的副本變量,該變量的生命周期應該由線程決定,線程消亡之后,該變量也應該被回收。
  • 多個線程有可能會同時操作 cacheMap,需要對 cacheMap 進行同步處理。

為了解決上面的問題,我們換種思路,每個線程創建一個 Map,存放當前線程中副本變量,用 CustomThreadLocal 的實例作為 key 值,下面是一個示例:

public class SimpleImpl2 {

    public static class CommonThread extends Thread {
        Map<Integer, Integer> cacheMap = new HashMap<>();
    }

    public static class CustomThreadLocal {
        private int defaultValue;

        public CustomThreadLocal(int value) {
            defaultValue = value;
        }

        public Integer get() {
            Integer id = this.hashCode();
            Map<Integer, Integer> cacheMap = getMap();
            if (cacheMap.containsKey(id)) {
                return cacheMap.get(id);
            }
            return defaultValue;
        }

        public void set(int value) {
            Integer id = this.hashCode();
            Map<Integer, Integer> cacheMap = getMap();
            cacheMap.put(id, value);
        }

        public Map<Integer, Integer> getMap() {
            CommonThread thread = (CommonThread) Thread.currentThread();
            return thread.cacheMap;
        }
    }

    public static class Number {
        private CustomThreadLocal value = new CustomThreadLocal(0);

        public void increase() throws InterruptedException {
            value.set(10);
            Thread.sleep(10);
            System.out.println("increase value: " + value.get());
        }

        public void decrease() throws InterruptedException {
            value.set(-10);
            Thread.sleep(10);
            System.out.println("decrease value: " + value.get());
        }
    }


    public static void main(String[] args) throws InterruptedException {
        final Number number = new Number();
        Thread increaseThread = new CommonThread() {
            @Override
            public void run() {
                try {
                    number.increase();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }

            }
        };

        Thread decreaseThread = new CommonThread() {
            @Override
            public void run() {
                try {
                    number.decrease();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        };
        increaseThread.start();
        decreaseThread.start();
    }
}

在上面的實現中,當線程消亡之后,線程中 cacheMap 也會被回收,它當中存放的副本變量也會被全部回收,並且 cacheMap 是線程私有的,不會出現多個線程同時訪問一個 cacheMap 的情況。在 Java 中,ThreadLocal 類的實現就是采用的這種思想,注意只是思想,實際的實現和上面的並不一樣。

使用示例

Java 使用 ThreadLocal 類來實現線程局部變量模式,ThreadLocal 使用 set 和 get 方法設置和獲取變量,下面是函數原型:

public void set(T value);
public T get();

下面是使用 ThreadLocal 的一個完整示例:

public class ThreadLocalDemo {
    private static ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
    private static int value = 0;

    public static class ThreadLocalThread implements Runnable {
        @Override
        public void run() {
            threadLocal.set((int)(Math.random() * 100));
            value = (int) (Math.random() * 100);
            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.printf(Thread.currentThread().getName() + ": threadLocal=%d, value=%d\n", threadLocal.get(), value);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Thread thread = new Thread(new ThreadLocalThread());
        Thread thread2 = new Thread(new ThreadLocalThread());
        thread.start();
        thread2.start();
        thread.join();
        thread2.join();
    }
}

下面是一種可能的輸出:

Thread-0: threadLocal=87, value=15
Thread-1: threadLocal=69, value=15

我們看到雖然 threadLocal 是靜態變量,但是每個線程都有自己的值,不會受到其他線程的影響。

具體實現

ThreadLocal 的實現思想,我們在前面已經說了,每個線程維護一個 ThreadLocalMap 的映射表,映射表的 key 是 ThreadLocal 實例本身,value 是要存儲的副本變量。ThreadLocal 實例本身並不存儲值,它只是提供一個在當前線程中找到副本值的 key。 如下圖所示:

圖片來自 http://blog.xiaohansong.com/2016/08/06/ThreadLocal-memory-leak/

我們從下面三個方面看下 ThreadLocal 的實現:

  • 存儲線程副本變量的數據結構
  • 如何存取線程副本變量
  • 如何對 ThreadLocal 的實例進行 Hash

ThreadLocalMap

線程使用 ThreadLocalMap 來存儲每個線程副本變量,它是 ThreadLocal 里的一個靜態內部類。ThreadLocalMap 也是采用的散列表(Hash)思想來實現的,但是實現方式和 HashMap 不太一樣。我們首先看下散列表的相關知識:

散列表

理想狀態下,散列表就是一個包含關鍵字的固定大小的數組,通過使用散列函數,將關鍵字映射到數組的不同位置。下面是理想散列表的一個示意圖:

圖片來自 數據結構與算法分析: C語法描述

在理想狀態下,哈希函數可以將關鍵字均勻的分散到數組的不同位置,不會出現兩個關鍵字散列值相同(假設關鍵字數量小於數組的大小)的情況。但是在實際使用中,經常會出現多個關鍵字散列值相同的情況(被映射到數組的同一個位置),我們將這種情況稱為散列沖突。為了解決散列沖突,主要采用下面兩種方式:

  • 分離鏈表法(separate chaining)
  • 開放定址法(open addressing)

分離鏈表法
分散鏈表法使用鏈表解決沖突,將散列值相同的元素都保存到一個鏈表中。當查詢的時候,首先找到元素所在的鏈表,然后遍歷鏈表查找對應的元素。下面是一個示意圖:

圖片來自 http://faculty.cs.niu.edu/~freedman/340/340notes/340hash.htm

開放定址法
開放定址法不會創建鏈表,當關鍵字散列到的數組單元已經被另外一個關鍵字占用的時候,就會嘗試在數組中尋找其他的單元,直到找到一個空的單元。探測數組空單元的方式有很多,這里介紹一種最簡單的 -- 線性探測法。線性探測法就是從沖突的數組單元開始,依次往后搜索空單元,如果到數組尾部,再從頭開始搜索(環形查找)。如下圖所示:

圖片來自 http://alexyyek.github.io/2014/12/14/hashCollapse/

關於兩種方式的比較,可以參考 這篇文章。ThreadLocalMap 中使用開放地址法來處理散列沖突,而 HashMap 中使用的分離鏈表法。之所以采用不同的方式主要是因為:在 ThreadLocalMap 中的散列值分散的十分均勻,很少會出現沖突。並且 ThreadLocalMap 經常需要清除無用的對象,使用純數組更加方便。

實現

我們知道 Map 是一種 key-value 形式的數據結構,所以在散列數組中存儲的元素也是 key-value 的形式。ThreadLocalMap 使用 Entry 類來存儲數據,下面是該類的定義:

static class Entry extends WeakReference <ThreadLocal <?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal <?> k, Object v) {
        super(k);
        value = v;
    }
}

Entry 將 ThreadLocal 實例作為 key,副本變量作為 value 存儲起來。注意 Entry 中對於 ThreadLocal 實例的引用是一個弱引用,該引用定義在 Reference 類(WeakReference的父類)中,下面是 super(k) 最終調用的代碼:

Reference(T referent) {
    this(referent, null);
}

Reference(T referent, ReferenceQueue <? super T> queue) {
    this.referent = referent;
    this.queue = (queue == null) ? ReferenceQueue.NULL : queue;
}

關於弱引用和為什么使用弱引用可以參考 Java 理論與實踐: 用弱引用堵住內存泄漏深入分析 ThreadLocal 內存泄漏問題。下面看一下 ThreadLocalMap 的 set 函數

private void set(ThreadLocal <?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    // 根據 ThreadLocal 的散列值,查找對應元素在數組中的位置
    int i = key.threadLocalHashCode & (len - 1);

    // 使用線性探測法查找元素
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal <?> k = e.get();
        // ThreadLocal 對應的 key 存在,直接覆蓋之前的值
        if (k == key) {
            e.value = value;
            return;
        }
        // key為 null,但是值不為 null,說明之前的 ThreadLocal 對象已經被回收了,當前數組中的 Entry 是一個陳舊(stale)的元素
        if (k == null) {
            // 用新元素替換陳舊的元素,這個方法進行了不少的垃圾清理動作,防止內存泄漏,具體可以看源代碼,沒看太懂
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // ThreadLocal 對應的 key 不存在並且沒有找到陳舊的元素,則在空元素的位置創建一個新的 Entry。
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // cleanSomeSlot 清理陳舊的 Entry(key == null),具體的參考源碼。如果沒有清理陳舊的 Entry 並且數組中的元素大於了閾值,則進行 rehash。
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

關於 set 方法,有幾點需要地方:

  • int i = key.threadLocalHashCode & (len - 1);,這里實際上是對 len-1 進行了取余操作。之所以能這樣取余是因為 len 的值比較特殊,是 2 的 n 次方,減 1 之后低位變為全 1,高位變為全 0。例如 16,減 1 之后對應的二進制為: 00001111,這樣其他數字中大於 16 的部分就會被 0 與掉,小於 16 的部分就會保留下來,就相當於取余了。
  • 在 replaceStaleEntry 和 cleanSomeSlots 方法中都會清理一些陳舊的 Entry,防止內存泄漏
  • threshold 的值大小為 threshold = len * 2 / 3;
  • rehash 方法中首先會清理陳舊的 Entry,如果清理完之后元素數量仍然大於 threshold 的 3/4,則進行擴容操作(數組大小變為原來的 2倍)
private void rehash() {
    expungeStaleEntries();
    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

我們再看一下 getEntry (沒有 get 方法,就叫 getEntry)方法:

private Entry getEntry(ThreadLocal <?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

因為 ThreadLocalMap 中采用開放定址法,所以當前 key 的散列值和元素在數組中的索引並不一定完全對應。所以在 get 的時候,首先會看 key 的散列值對應的數組元素是否為要查找的元素,如果不是,再調用 getEntryAfterMiss 方法查找后面的元素。

private Entry getEntryAfterMiss(ThreadLocal <?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal < ? > k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

最后看一下刪除操作。刪除其實就是將 Entry 的鍵值設為 null,變為陳舊的 Entry。然后調用 expungeStaleEntry 清理陳舊的 Entry。

private void remove(ThreadLocal <?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len - 1);
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

副本變量存取

前面說完了 ThreadLocalMap,副本變量的存取操作就很好理解了。下面是 ThreadLocal 中的 set 和 get 方法的實現:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T) e.value;
            return result;
        }
    }
    return setInitialValue();
}

存取的基本流程就是首先獲得當前線程的 ThreadLocalMap,將 ThreadLocal 實例作為鍵值傳入 Map,然后就是進行相關的變量存取工作了。線程中的 ThreadLocalMap 是懶加載的,只有真正的要存變量時才會調用 createMap 創建,下面是 createMap 的實現:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

如果想要給 ThreadLocal 的副本變量設置初始值,需要重寫 initialValue 方法,如下面的形式:

ThreadLocal <Integer> threadLocal = new ThreadLocal() {
    protected Integer initialValue() {
        return 0;
    }
};

ThreadLocal 散列值

當創建了一個 ThreadLocal 的實例后,它的散列值就已經確定了,下面是 ThreadLocal 中的實現:

/**
 * ThreadLocals rely on per-thread linear-probe hash maps attached
 * to each thread (Thread.threadLocals and
 * inheritableThreadLocals).  The ThreadLocal objects act as keys,
 * searched via threadLocalHashCode.  This is a custom hash code
 * (useful only within ThreadLocalMaps) that eliminates collisions
 * in the common case where consecutively constructed ThreadLocals
 * are used by the same threads, while remaining well-behaved in
 * less common cases.
 */
private final int threadLocalHashCode = nextHashCode();

/**
 * The next hash code to be given out. Updated atomically. Starts at
 * zero.
 */
private static AtomicInteger nextHashCode =
    new AtomicInteger();

/**
 * The difference between successively generated hash codes - turns
 * implicit sequential thread-local IDs into near-optimally spread
 * multiplicative hash values for power-of-two-sized tables.
 */
private static final int HASH_INCREMENT = 0x61c88647;

/**
 * Returns the next hash code.
 */
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

我們看到 threadLocalHashCode 是一個常量,它通過 nextHashCode() 函數產生。nextHashCode() 函數其實就是在一個 AtomicInteger 變量(初始值為0)的基礎上每次累加 0x61c88647,使用 AtomicInteger 為了保證每次的加法是原子操作。而 0x61c88647 這個就比較神奇了,它可以使 hashcode 均勻的分布在大小為 2 的 N 次方的數組里。下面寫個程序測試一下:

public static void main(String[] args) {
    AtomicInteger hashCode = new AtomicInteger();
    int hash_increment = 0x61c88647;
    int size = 16;
    List <Integer> list = new ArrayList <> ();
    for (int i = 0; i < size; i++) {
        list.add(hashCode.getAndAdd(hash_increment) & (size - 1));
    }
    System.out.println("original:" + list);
    Collections.sort(list);
    System.out.println("sort:    " + list);
}

我們將 size 設為 16,32 和 64 分別測試一下:

// size=16
original:[0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
sort:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

// size=32
original:[0, 7, 14, 21, 28, 3, 10, 17, 24, 31, 6, 13, 20, 27, 2, 9, 16, 23, 30, 5, 12, 19, 26, 1, 8, 15, 22, 29, 4, 11, 18, 25]
sort:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]

// size=64
original:[0, 7, 14, 21, 28, 35, 42, 49, 56, 63, 6, 13, 20, 27, 34, 41, 48, 55, 62, 5, 12, 19, 26, 33, 40, 47, 54, 61, 4, 11, 18, 25, 32, 39, 46, 53, 60, 3, 10, 17, 24, 31, 38, 45, 52, 59, 2, 9, 16, 23, 30, 37, 44, 51, 58, 1, 8, 15, 22, 29, 36, 43, 50, 57]
sort:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]

可以看到隨着 size 的變化,hashcode 總能均勻的分布。其實這就是 Fibonacci Hashing,具體可以參考 這篇文章。所以雖然 ThreadLocal 的 hashcode 是固定的,當 ThreadLocalMap 中的散列表調整大小(變為原來的 2 倍)之后重新散列,hashcode 仍能均勻的分布在散列表中。

應用場景

摘自 Java並發編程:深入剖析ThreadLocal

最常見的ThreadLocal使用場景為 用來解決 數據庫連接、Session管理等。如

private static ThreadLocal < Connection > connectionHolder = new ThreadLocal < Connection > () {
    public Connection initialValue() {
        return DriverManager.getConnection(DB_URL);
    }
};

public static Connection getConnection() {
    return connectionHolder.get();
}
private static final ThreadLocal threadSession = new ThreadLocal();

public static Session getSession() throws InfrastructureException {
    Session s = (Session) threadSession.get();
    try {
        if (s == null) {
            s = getSessionFactory().openSession();
            threadSession.set(s);
        }
    } catch (HibernateException ex) {
        throw new InfrastructureException(ex);
    }
    return s;
}

參考文章


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM