ConcurrentBag可以理解為一個線程安全的無序集合,它提供的API比我們常用的List<T>要少一些。下面我們來看看它的實現(以下代碼是節選,省略了部分成員,例如ThreadLocalList的Count屬性以及GlobalListsLock、ListOperation的定義):
/// <summary>
/// Thread-safe, unordered bag. Every thread that adds items gets its own ThreadLocalList
/// (held in m_locals). All lists are also chained through m_nextList starting at m_headList,
/// so a thread whose own list is missing or empty can steal items from other threads' lists.
/// </summary>
/// <remarks>
/// NOTE(review): this is an excerpt. It references members that are not shown here —
/// ThreadLocalList.Count, GlobalListsLock, ListOperation, and the
/// IProducerConsumerCollection / IReadOnlyCollection members. They must be defined
/// elsewhere for this class to compile.
/// </remarks>
public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T>
{
    // The calling thread's own ThreadLocalList (null until GetThreadList(true) assigns one).
    ThreadLocal<ThreadLocalList> m_locals;

    // First and last ThreadLocalList of the m_nextList chain, so every thread's list can be enumerated.
    // New lists are appended at m_tailList under GlobalListsLock (see GetThreadList).
    volatile ThreadLocalList m_headList, m_tailList;

    // When set, AddInternal and the take path of TryTakeOrPeek always lock the local list
    // ("a thread needs to freeze the bag"). NOTE(review): never assigned in this excerpt.
    bool m_needSync;

    /// <summary>Creates an empty bag.</summary>
    public ConcurrentBag() { Initialize(null);}

    /// <summary>Creates a bag containing the items of <paramref name="collection"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
    public ConcurrentBag(IEnumerable<T> collection)
    {
        if (collection == null)
        {
            throw new ArgumentNullException("collection", SR.GetString(SR.ConcurrentBag_Ctor_ArgumentNullException));
        }
        Initialize(collection);
    }

    /// <summary>
    /// Creates m_locals and, if <paramref name="collection"/> is non-null, copies its items
    /// into the calling thread's list.
    /// </summary>
    private void Initialize(IEnumerable<T> collection)
    {
        m_locals = new ThreadLocal<ThreadLocalList>();

        // Copy the collection to the bag
        if (collection != null)
        {
            ThreadLocalList list = GetThreadList(true);
            foreach (T item in collection)
            {
                // updateCount == false: do not fold m_stealCount into m_count during the initial copy.
                list.Add(item, false);
            }
        }
    }

    /// <summary>Adds <paramref name="item"/> to the calling thread's local list.</summary>
    public void Add(T item)
    {
        // Get the local list for that thread; create (or reuse an unowned) list if this thread
        // has none yet (first time it calls Add).
        ThreadLocalList list = GetThreadList(true);
        AddInternal(list, item);
    }

    /// <summary>
    /// Adds <paramref name="item"/> to <paramref name="list"/> (the calling thread's own list).
    /// Locks the list only when it is small or m_needSync is set; otherwise it runs lock-free and
    /// advertises the in-progress Add through m_currentOp.
    /// </summary>
    private void AddInternal(ThreadLocalList list, T item)
    {
        bool lockTaken = false;
        try
        {
            // Publish that an Add is in progress. CanSteal spins while m_currentOp != None on small lists.
            Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add);
            //Synchronization cases:
            // if the list count is less than two to avoid conflict with any stealing thread
            // if m_needSync is set, this means there is a thread that needs to freeze the bag
            if (list.Count < 2 || m_needSync)
            {
                // Reset it back to None before locking. A stealer that already holds lock(list) may be
                // spinning in CanSteal until m_currentOp becomes None. Leaving it set would deadlock.
                list.m_currentOp = (int)ListOperation.None;
                Monitor.Enter(list, ref lockTaken);
            }
            // Only a synchronized Add (lockTaken == true) folds m_stealCount back into m_count.
            list.Add(item, lockTaken);
        }
        finally
        {
            list.m_currentOp = (int)ListOperation.None;
            if (lockTaken)
            {
                Monitor.Exit(list);
            }
        }
    }

    /// <summary>
    /// Returns the calling thread's ThreadLocalList.
    /// If the thread has none and <paramref name="forceCreate"/> is true, it acquires GlobalListsLock.
    /// It then reuses a list whose owner thread has stopped, or appends a new list to the chain, and
    /// stores the result in m_locals.
    /// </summary>
    /// <returns>The thread's list, or null when it has none and <paramref name="forceCreate"/> is false.</returns>
    private ThreadLocalList GetThreadList(bool forceCreate)
    {
        ThreadLocalList list = m_locals.Value;

        if (list != null)
        {
            return list;
        }
        else if (forceCreate)
        {
            // Acquire the lock to update the m_tailList pointer
            lock (GlobalListsLock)
            {
                if (m_headList == null)
                {
                    // No list exists yet for this bag: this list becomes both head and tail.
                    list = new ThreadLocalList(Thread.CurrentThread);
                    m_headList = list;
                    m_tailList = list;
                }
                else
                {
                    // Prefer recycling a list whose owner thread is Stopped; otherwise append a new one.
                    list = GetUnownedList();
                    if (list == null)
                    {
                        list = new ThreadLocalList(Thread.CurrentThread);
                        m_tailList.m_nextList = list;
                        m_tailList = list;
                    }
                }
                m_locals.Value = list;
            }
        }
        else
        {
            return null;
        }
        Debug.Assert(list != null);
        return list;
    }

    /// <summary>Removes an item: from the calling thread's list if possible, otherwise stolen from another thread's list.</summary>
    /// <returns>True if an item was removed; false if the bag appeared empty.</returns>
    public bool TryTake(out T result) { return TryTakeOrPeek(out result, true); }

    /// <summary>Returns an item without removing it, using the same lookup order as <see cref="TryTake"/>.</summary>
    /// <returns>True if an item was found; false if the bag appeared empty.</returns>
    public bool TryPeek(out T result) { return TryTakeOrPeek(out result, false); }

    /// <summary>
    /// Shared implementation of TryTake (<paramref name="take"/> == true) and TryPeek.
    /// It works on the head of the calling thread's list. It falls back to Steal when that list is
    /// missing or empty.
    /// </summary>
    private bool TryTakeOrPeek(out T result, bool take)
    {
        // Get the local list for that thread; null if this thread has no list
        //(this thread never added before)
        ThreadLocalList list = GetThreadList(false);
        if (list == null || list.Count == 0)
        {
            return Steal(out result, take);
        }

        bool lockTaken = false;
        try
        {
            if (take) // Take operation
            {
                // Publish the in-progress Take so CanSteal can wait for it on small lists.
                Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Take);
                //Synchronization cases:
                // if the list count is less than or equal two to avoid conflict with any stealing thread
                // if m_needSync is set, this means there is a thread that needs to freeze the bag
                if (list.Count <= 2 || m_needSync)
                {
                    // Reset it back to None before locking (same deadlock-avoidance as in AddInternal).
                    list.m_currentOp = (int)ListOperation.None;
                    Monitor.Enter(list, ref lockTaken);

                    // Double check the count and steal if it became empty
                    if (list.Count == 0)
                    {
                        // Release the lock before stealing
                        if (lockTaken)
                        {
                            // Empty try, work in finally: presumably so this release cannot be interrupted
                            // by an asynchronous thread abort. NOTE(review): confirm intent.
                            try { }
                            finally
                            {
                                lockTaken = false; // reset lockTaken to avoid calling Monitor.Exit again in the finally block
                                Monitor.Exit(list);
                            }
                        }
                        return Steal(out result, true);
                    }
                }
                // Either locked (small list / sync requested) or lock-free with Count > 2, in which case
                // the owner works at the head while any stealer works at the tail.
                list.Remove(out result);
            }
            else
            {
                // Peek path: reads the head without locking or setting m_currentOp.
                if (!list.Peek(out result))
                {
                    return Steal(out result, false);
                }
            }
        }
        finally
        {
            list.m_currentOp = (int)ListOperation.None;
            if (lockTaken)
            {
                Monitor.Exit(list);
            }
        }
        return true;
    }

    /// <summary>
    /// Takes or peeks (per <paramref name="take"/>) an item from any thread's list.
    /// The first pass walks the chain from m_headList, records each list's m_version, and tries to
    /// steal from every non-empty list. If nothing was obtained, a second pass re-checks the recorded
    /// versions. It retries any list whose version changed (ThreadLocalList.Add bumps m_version when
    /// a list goes from empty to non-empty). It repeats the whole scan until a pass sees no version
    /// change.
    /// </summary>
    /// <returns>True if an item was obtained; otherwise false with <paramref name="result"/> = default(T).</returns>
    private bool Steal(out T result, bool take)
    {
        bool loop;
        List<int> versionsList = new List<int>(); // save the lists version
        do
        {
            versionsList.Clear(); //clear the list from the previous iteration
            loop = false;

            ThreadLocalList currentList = m_headList;
            while (currentList != null)
            {
                versionsList.Add(currentList.m_version);
                if (currentList.m_head != null && TrySteal(currentList, out result, take))
                {
                    return true;
                }
                currentList = currentList.m_nextList;
            }

            // Verify versioning: if items were added to a list since we last visited it, retry it.
            // In this excerpt lists are only ever appended (GetThreadList), so this walk revisits the
            // same lists in the same order as the first pass.
            currentList = m_headList;
            foreach (int version in versionsList)
            {
                if (version != currentList.m_version) //oops state changed
                {
                    loop = true;
                    if (currentList.m_head != null && TrySteal(currentList, out result, take))
                        return true;
                }
                currentList = currentList.m_nextList;
            }
        } while (loop);

        result = default(T);
        return false;
    }

    /// <summary>
    /// While holding lock(list), takes or peeks the tail item of <paramref name="list"/> if CanSteal allows it.
    /// </summary>
    /// <returns>True if an item was obtained; otherwise false with <paramref name="result"/> = default(T).</returns>
    private bool TrySteal(ThreadLocalList list, out T result, bool take)
    {
        lock (list)
        {
            if (CanSteal(list))
            {
                list.Steal(out result, take);
                return true;
            }
            result = default(T);
            return false;
        }
    }

    /// <summary>
    /// Called with lock(list) held (from TrySteal).
    /// If the list has at most two items and its owner is mid-operation (m_currentOp != None), it
    /// spins until that operation finishes. Presumably this is because, with so few nodes, the
    /// owner's head-side work and a tail-side steal can touch the same nodes.
    /// </summary>
    /// <returns>True if the list still has at least one item.</returns>
    private bool CanSteal(ThreadLocalList list)
    {
        if (list.Count <= 2 && list.m_currentOp != (int)ListOperation.None)
        {
            SpinWait spinner = new SpinWait();
            while (list.m_currentOp != (int)ListOperation.None)
            {
                spinner.SpinOnce();
            }
        }
        if (list.Count > 0)
        {
            return true;
        }
        return false;
    }

    /// <summary>
    /// Tries to reuse an unowned list, if one exists.
    /// An unowned list is one whose owner thread's ThreadState is Stopped (for example, aborted or
    /// terminated). Reusing it is a workaround to avoid leaking memory.
    /// The first such list found is reassigned to the calling thread and returned.
    /// </summary>
    /// <returns>The list object, or null if all lists are owned.</returns>
    private ThreadLocalList GetUnownedList()
    {
        //the global lock must be held at this point
        Contract.Assert(Monitor.IsEntered(GlobalListsLock));

        ThreadLocalList currentList = m_headList;
        while (currentList != null)
        {
            if (currentList.m_ownerThread.ThreadState == System.Threading.ThreadState.Stopped)
            {
                currentList.m_ownerThread = Thread.CurrentThread; // the caller should acquire a lock to make this line thread safe
                return currentList;
            }
            currentList = currentList.m_nextList;
        }
        return null;
    }

    /// <summary>
    /// One thread's items, stored as a doubly linked list of Node.
    /// The owner adds, removes and peeks at m_head (newest item). Stealers take from m_tail (oldest item).
    /// </summary>
    internal class ThreadLocalList
    {
        // Newest item; the owner thread adds/removes/peeks here.
        internal volatile Node m_head;
        // Oldest item; Steal takes from here.
        private volatile Node m_tail;
        // ListOperation the owner is currently performing (None when idle); read by CanSteal.
        internal volatile int m_currentOp;
        // Incremented by Add, decremented by Remove; steals are tracked separately in m_stealCount
        // and subtracted here by a synchronized Add.
        // NOTE(review): the Count property read elsewhere is not shown in this excerpt — presumably derived from these two fields.
        private int m_count;
        // Number of items stolen since the last synchronized Add folded it into m_count.
        internal int m_stealCount;
        // Next list in the bag's global chain (head is ConcurrentBag.m_headList).
        internal volatile ThreadLocalList m_nextList;
        // NOTE(review): not used anywhere in this excerpt.
        internal bool m_lockTaken;
        // Thread that currently owns this list; reassigned by GetUnownedList when the owner has stopped.
        internal Thread m_ownerThread;
        // Incremented when the list goes from empty to non-empty; used by ConcurrentBag.Steal to detect changes.
        internal volatile int m_version;

        internal ThreadLocalList(Thread ownerThread)
        {
            m_ownerThread = ownerThread;
        }

        /// <summary>
        /// Pushes <paramref name="item"/> at the head. When the list was empty it also becomes the tail,
        /// and m_version is bumped.
        /// </summary>
        /// <param name="item">The item to add.</param>
        /// <param name="updateCount">
        /// True when the caller holds the list's lock. In that case m_stealCount is subtracted from
        /// m_count and reset to zero.
        /// </param>
        /// <exception cref="OverflowException">m_count overflows (checked arithmetic).</exception>
        internal void Add(T item, bool updateCount)
        {
            checked
            {
                m_count++;
            }
            Node node = new Node(item);
            if (m_head == null)
            {
                Debug.Assert(m_tail == null);
                m_head = node;
                m_tail = node;
                m_version++; // changing from empty state to non empty state
            }
            else
            {
                node.m_next = m_head;
                m_head.m_prev = node;
                m_head = node;
            }
            if (updateCount) // update the count to avoid overflow if this add is synchronized
            {
                m_count = m_count - m_stealCount;
                m_stealCount = 0;
            }
        }

        /// <summary>
        /// Removes an item from the head of the list. The list must be non-empty (asserted).
        /// </summary>
        /// <param name="result">The removed item</param>
        internal void Remove(out T result)
        {
            Debug.Assert(m_head != null);
            Node head = m_head;
            m_head = m_head.m_next;
            if (m_head != null)
            {
                m_head.m_prev = null;
            }
            else
            {
                // Removed the last node: the list is now empty.
                m_tail = null;
            }
            m_count--;
            result = head.m_value;
        }

        /// <summary>
        /// Peeks at the item at the head of the list without removing it.
        /// </summary>
        /// <param name="result">the peeked item, or default(T) if the list is empty</param>
        /// <returns>True if succeeded, false otherwise</returns>
        internal bool Peek(out T result)
        {
            // Read the volatile head once so the null check and the value come from the same node.
            Node head = m_head;
            if (head != null)
            {
                result = head.m_value;
                return true;
            }
            result = default(T);
            return false;
        }

        /// <summary>
        /// Returns the tail (oldest) item. If <paramref name="remove"/> is true it also unlinks that
        /// item and records the removal in m_stealCount rather than m_count.
        /// The list must be non-empty (asserted). In this excerpt it is only called from
        /// ConcurrentBag.TrySteal, under lock(this).
        /// </summary>
        internal void Steal(out T result, bool remove)
        {
            Node tail = m_tail;
            Debug.Assert(tail != null);
            if (remove) // Take operation
            {
                m_tail = m_tail.m_prev;
                if (m_tail != null)
                {
                    m_tail.m_next = null;
                }
                else
                {
                    // Removed the last node: the list is now empty.
                    m_head = null;
                }
                // Increment the steal count
                m_stealCount++;
            }
            result = tail.m_value;
        }
    }

    /// <summary>
    /// Node of ThreadLocalList's doubly linked list.
    /// m_next points toward the tail (older items) and m_prev toward the head (newer items).
    /// </summary>
    internal class Node
    {
        public Node(T value)
        {
            m_value = value;
        }
        public readonly T m_value;
        public Node m_next;
        public Node m_prev;
    }
}
首先我們需要知道裡面有2個內部類Node和ThreadLocalList,它們都和鏈表結構有關。其中Node是雙向鏈表的節點,因為它有m_next和m_prev兩個字段。ThreadLocalList則通過m_head和m_tail字段指向一條由Node組成的雙向鏈表;同時它自身又只通過m_nextList字段相連,所以所有的ThreadLocalList串成了一條單向鏈表,而每個ThreadLocalList都屬於一個擁有者線程m_ownerThread。現在我們來看ConcurrentBag的幾個變量。ThreadLocal<ThreadLocalList> m_locals保存當前線程自己的list,所以從這裡我們可以猜測,線程安全主要是靠ThreadLocal讓每個線程操作各自的list來實現的;不過後面會看到,在list元素很少或需要同步時,仍然要用Monitor加鎖。volatile ThreadLocalList m_headList, m_tailList;這2個變量分別指向上述單向鏈表的第一個和最後一個ThreadLocalList,用來遍歷所有線程的list。
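無論是初始化的Initialize方法還是添加元素的Add方法,首先都要調用GetThreadList方法來獲取當前線程的list。GetThreadList方法首先檢查當前線程的m_locals.Value是否存在,有則直接返回。否則(在forceCreate為true時)先獲取全局鎖GlobalListsLock,然後檢查該集合是否還沒有任何ThreadLocalList【m_headList == null】,也就是說還沒有任何線程為這個集合創建過list。如果是,則創建新的ThreadLocalList,並讓m_headList和m_tailList都指向它。否則調用GetUnownedList方法,檢查是否有可以重用的孤立ThreadLocalList【該ThreadLocalList的擁有者線程m_ownerThread已經停止(ThreadState.Stopped),但該ThreadLocalList實例卻仍然存在】。如果有,就把它的擁有者改為當前線程並返回該ThreadLocalList;否則只能新建一個ThreadLocalList實例,並把它追加到m_tailList之後。最後把得到的list保存到m_locals.Value中。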
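現在看看AddInternal方法的實現。首先用Interlocked.Exchange把ThreadLocalList的m_currentOp標記為正在添加元素【Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add)】,然後再添加元素【list.Add(item, lockTaken);】。如果該list需要加鎖,即list中的元素少於2個(此時可能與竊取線程衝突)或者m_needSync被設置(有線程需要凍結整個集合),那麼在添加元素前要先把m_currentOp重置為None(避免與竊取線程死鎖),再加鎖Monitor.Enter(list, ref lockTaken)。添加完成後,在finally中重置m_currentOp並解鎖Monitor.Exit(list)。ThreadLocalList的Add方法非常簡單,就是把新節點放到鏈表頭部【node.m_next = m_head; m_head.m_prev = node; m_head = node;】。如果鏈表原來是空的,則讓m_head和m_tail都指向新節點,並把版本號m_version加1,後面的Steal方法會用到這個版本號。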
添加元素時是添加到各個線程自己的ThreadLocalList中的,所以讀取就比較麻煩了:我們可能需要讀取各個線程ThreadLocalList的數據,也就是說需要通過m_headList遍歷所有的list。如果當前線程存在ThreadLocalList實例並且其中有數據,那麼直接從這個list的頭部拿取數據;Remove和Peek都操作m_head,所以對本線程來說是後進先出的順序。Take時,如果list中的元素不多於2個或者m_needSync被設置,就需要先加鎖【Monitor.Enter(list, ref lockTaken)】。加鎖後如果發現list已經變空,則先解鎖再去其他線程竊取;否則在操作完成後於finally中解鎖【Monitor.Exit(list)】。這裡鎖的都是當前線程自己的list。如果當前線程的ThreadLocalList不存在,或者沒有數據,我們就需要從其他線程的ThreadLocalList獲取數據。Steal方法首先會從m_headList開始依次遍歷每一個ThreadLocalList,記錄它們的版本號m_version,並嘗試從中獲取數據。如果一個都獲取不到,就再遍歷一遍所有的ThreadLocalList,檢查哪些ThreadLocalList的版本m_version在這期間發生了變化(即由空變為非空),並對這些list再嘗試竊取一次。只要出現過版本變化,就重新開始整個循環,直到某一輪遍歷期間所有版本都沒有變化為止。
// Retry loop from ConcurrentBag<T>.Steal. versionsList, loop, result and take are locals/parameters
// of the enclosing method.
do
{
    versionsList.Clear(); //clear the list from the previous iteration
    loop = false;

    // First pass: record each list's version and try to steal from every non-empty list.
    ThreadLocalList currentList = m_headList;
    while (currentList != null)
    {
        versionsList.Add(currentList.m_version);
        if (currentList.m_head != null && TrySteal(currentList, out result, take))
        {
            return true;
        }
        currentList = currentList.m_nextList;
    }

    // Verify versioning: if items were added to a list since we last visited it, retry it.
    // ThreadLocalList.Add bumps m_version when a list goes from empty to non-empty.
    currentList = m_headList;
    foreach (int version in versionsList)
    {
        if (version != currentList.m_version) //oops state changed
        {
            // Try the changed list again now; if that still fails, run another full pass.
            loop = true;
            if (currentList.m_head != null && TrySteal(currentList, out result, take))
                return true;
        }
        currentList = currentList.m_nextList;
    }
} while (loop);
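TrySteal方法的實現就非常簡單了:先對目標list加鎖【lock (list)】,再檢查是否可以從該list竊取數據【CanSteal(list)】。CanSteal裡面用到了自旋等待:當list中的元素不多於2個,且其擁有者線程正在進行添加或獲取操作(m_currentOp不為None)時,就用SpinWait自旋,直到該操作結束【if (list.Count <= 2 && list.m_currentOp != (int)ListOperation.None){ SpinWait spinner = new SpinWait(); while (list.m_currentOp != (int)ListOperation.None) {spinner.SpinOnce(); } }】。之後只要list中還有元素,就返回true。真正的竊取由ThreadLocalList的Steal方法完成,也比較簡單。它從鏈表尾部m_tail取元素,與擁有者線程操作的頭部方向相反。如果是Take操作,就把尾節點從鏈表中移除,並把m_stealCount加1,而不是直接修改m_count。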