概述
集合類中的sort方法,聽說在java7中就引入了,但是我沒有用過java7,不太清楚,java8中的排序是采用Timsort排序算法實現的,這個排序最開始是在python中由Tim Peters實現的,后來Java覺得不錯,就引入了這個排序到Java中,竟然以作者的名字命名,搞得我還以為這個Tim是一個單詞的意思,了不起,本文就從Arrays中實現的排序分析一下這個排序算法的原理,本文只會從源碼角度分析,不會從算法角度去分析。
進入List中查看sort方法源碼如下:
default void sort(Comparator<? super E> c) { Object[] a = this.toArray(); // 這個方法很簡單,就是調用Arrays中的sort方法進行排序 Arrays.sort(a, (Comparator) c); ListIterator<E> i = this.listIterator(); for (Object e : a) { i.next(); i.set((E) e); } }
進入Arrays.sort()方法
public static <T> void sort(T[] a, Comparator<? super T> c) { //這個是自己傳入的比較器如果為空,這里的a為存放數據的數組,那數組中的的元素都必須實現Comparator接口 if (c == null) { sort(a); } else {
//這個判斷沒有細看,不太清楚在判斷什么 if (LegacyMergeSort.userRequested) legacyMergeSort(a, c); else //這里就是所謂的TimSort了 TimSort.sort(a, 0, a.length, c, null, 0, 0); } }
由於sort()和TimSort.sort走的流程基本一致,這里只分析TimSort.sort()方法,進入該方法。
這里有必要先說一下TimSort排序算法的核心內容,了解這個算法的核心內容有助於看下面的代碼。
TimSort的核心是這樣:
1.如果數組的長度小於32,直接采用二分法插入排序,就是下面方法中的binarySort方法實現的,這個算法原理,我舉個例子大家就明白了
假設數組為:[1,3,9,6,2],二分法插入排序插入如下:
I:從開頭先把自然升序(或降序)段找出來,那什么是自然升序段,就是沒有經過排序算法,原始數據就是有序的,本數組中自然升序段就是1,3,9
II:按照正常的思維,直接拿6,從頭開始和前三個元素一個一個比較也可以實現排序,但是這樣效率太低,那怎么做可以效率高點呢?就是我們之前高數中學的二分查找法,就是通過二分查找法,我先找到3,發現6 > 3,那我就不用和1進行比較了。
2.如果數組的長度大於32,那就把數組拆分成一個一個的小段,每段的長度在16~32之間,使用上面介紹的二分法插入排序,把每一段進行排序,之后在把每一個排好序的段進行合並,最終就可以實現整個數組的排序,大致的思想就是這樣,這個敘述可能會給大家一種誤解,就是認為每一段都排好序之后在進行合並,其實不是這樣的,而是每一段邊排序,如果符合特定條件就會合並。
有了上面的了解,我們再來看下面的代碼
static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c, T[] work, int workBase, int workLen) { assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length; //這里是數組中剩余沒有排序的元素個數,初始長度為數組的長度 int nRemaining = hi - lo; if (nRemaining < 2) return; // Arrays of size 0 and 1 are always sorted //這里的MIN_MERGE就是32,如果數組長度小於32,直接采用二分法插入排序 // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { int initRunLen = countRunAndMakeAscending(a, lo, hi, c); binarySort(a, lo, hi, lo + initRunLen, c); return; } /** * March over the array once, left to right, finding natural runs, * extending short natural runs to minRun elements, and merging runs * to maintain stack invariant. */ //這個是TimSort核心類,很多處理邏輯都是這個里面 TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen); //<1.1> 這個就是計算分割之后每一段的長度的 int minRun = minRunLength(nRemaining); do { // Identify next run //<1.2> 尋找自然增長的結束位置 int runLen = countRunAndMakeAscending(a, lo, hi, c); // If run is short, extend to min(minRun, nRemaining) if (runLen < minRun) { int force = nRemaining <= minRun ? nRemaining : minRun; //<1.3> 對每一段進行二分法插入排序 binarySort(a, lo, lo + force, lo + runLen, c); runLen = force; } // Push run onto pending-run stack, and maybe merge,將每一段的起始位置和每一段的分段長度放入棧中 ts.pushRun(lo, runLen); //<1.4> 合並排好序的段,這個就是上面我說的並不是等所有的都排好序了再合並 ts.mergeCollapse(); // Advance to find next run lo += runLen;
// 將已經排好序的段,從總長度中減去 nRemaining -= runLen; } while (nRemaining != 0); // Merge all remaining runs to complete sort assert lo == hi; //<1.5> 最終排序,這個方法在整個排序中只會執行一次 ts.mergeForceCollapse(); assert ts.stackSize == 1; }
上面注釋中<1.1>, minRunLength(nRemaining)
private static int minRunLength(int n) { assert n >= 0; int r = 0; // Becomes 1 if any 1 bits are shifted off while (n >= MIN_MERGE) {
//這一段有點繞,其實意思就是當n是奇數的時候r = 1 r |= (n & 1);
//這段的意思就是一直除於2,直到n < 32為止 n >>= 1; }
//如果n是偶數,結果就是一直除於2的結果,如果是奇數,就是一直除於二加一 return n + r; }
上面注釋<1.2>, countRunAndMakeAscending(a, lo, hi, c);
//解釋一下各個參數:a就是存放元素的數組,lo第一個元素的位置(注意這個第一個元素並不一定是數組的第一個元素的位置,而是每一段的第一個元素),hi表示數組的長度,c為比較器 private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi, Comparator<? super T> c) { assert lo < hi; int runHi = lo + 1; if (runHi == hi) return 1; // Find end of run, and reverse range if descending //下面的if...else就是尋找自增序列的,if中判斷的情況是尋找自然降序的 if (c.compare(a[runHi++], a[lo]) < 0) { // Descending while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0) runHi++;
//找到降序的段之后,進行反轉成升序 reverseRange(a, lo, runHi); } else { // Ascending,前面的英文是原來的注釋,可以看到這個是找升序段的位置的 while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0) runHi++; } //最后返回的結果其實第一個違反將序或升序的元素位置減去第一個元素的位置,舉例[1,3,5,2,4],那么runHi=3,lo = 0,最后返回3 return runHi - lo; }
這個方法其實就是上面我在說TimSort原理的時候講到的,尋找自然升序或將序段,這么做的原因是減小排序時候元素的個數,加快排序速度。
上面注釋<1.3>,binarySort(a, lo, lo + force, lo + runLen, c);這個方法是核心排序方法,使用的是二分法插入排序算法
//先解釋一下各個參數:a為存放元素的數組,lo是各個分段的起始位置,hi為數組的長度,start就是coutRunAndMakeAsending()方法返回的結果加上起始結果 private static <T> void binarySort(T[] a, int lo, int hi, int start, Comparator<? super T> c) { assert lo <= start && start <= hi; if (start == lo) start++; for ( ; start < hi; start++) { // 備份start位置的值,因為這個直后面可能被覆蓋掉 T pivot = a[start]; // Set left (and right) to the index where a[start] (pivot) belongs int left = lo; int right = start; assert left <= right; /* * Invariants: * pivot >= all in [lo, left). * pivot < all in [right, start). */ //下面這個while循環就是一個二分查找法的過程,先確定二分查找法的范圍,就是left和rigth,之后每次找到left和rigth的中間點,那后比較 while (left < right) {
//尋找中間點 int mid = (left + right) >>> 1;
//比較大小 if (c.compare(pivot, a[mid]) < 0) right = mid; else left = mid + 1; } assert left == right; /* * The invariants still hold: pivot >= all in [lo, left) and * pivot < all in [left, start), so pivot belongs at left. Note * that if there are elements equal to pivot, left points to the * first slot after them -- that's why this sort is stable. * Slide elements over to make room for pivot. */ int n = start - left; // The number of elements to move // Switch is just an optimization for arraycopy in default case ,這個switch case用的非常講究,當你明白了這個玩意,你就不得不佩服大佬,看看真正的大佬是如何把普通的東西玩出不一樣 switch (n) { case 2: a[left + 2] = a[left + 1]; case 1: a[left + 1] = a[left]; break; default: System.arraycopy(a, left, a, left + 1, n); } a[left] = pivot; } }
這個方法總的來說還是很好懂的,就是switch case那一塊用的很牛逼,下面我就說一下這一塊。我會對case對應的每一種情況舉一個例子大家就明白這里為什么吊了。
case n = 2: 假設a = [1,2,5,8,9,6],當使用二分查找法定位的時候一定可以定位到5后面的位置,也就是a[3],這個時候要怎么做呢,這時left = 3,那a[left + 2] = a[left + 1];就是a[5] = a[4],就是把9放到6的位置。這還沒有完,因為這個case后面沒有continue也沒有break,會繼續執行后面的case,也就是會執行a[left + 1] = a[left];就是a[4] = a[3],相當於把8放到9的位置,這個時候來了一個break,后面的default就不會執行了,但是會執行a[left] = pivot;這句相當於a[3] = 6,這樣就實現了一個交換動作(如果對switch...case中使用break不熟悉的建議先查一下這個)。
case n = 1: 假設a = [1,2,5,4],這時left = 2,執行a[left + 1] = a[left]相當於a[3] = a[2] = 5;最后執行a[left] = pivot; 就是a[2] = 4,這樣做完之后就變成了[1,2,4,5]
default: 就是n > 2的時候,思考一下這里n不可能為0,因為n為0說明就是正常的升序,這個在前一個方法尋找自然自增序列的時候已經處理了這種情況,那就是當n > 2時候,舉例如下:
a = [1,3,4,4,5,2],這種時候n = 4如果還是一個元素一個元素移動,那效率太低了,這時候使用了System.arraycopy方法,關於這個方法我寫個例子大家看一下就知道這個方法可以干什么了。
public static void copyArray(){ Integer[] a = {1,3,4,4,5,2};
//關於這個參數大家可以看一下源碼 System.arraycopy(a,1,a,2,4); for (int i = 0; i < a.length; i++) { System.out.println(a[i]); } } public static void main(String[] args) { copyArray(); }
輸出結果為:1 3 3 4 4 5
可以看出這個方法的作用其實就是把數組元素下標1到4的元素拷貝到2到5,最后一個參數4表示拷貝的元素個數,可以發現最后一個元素2被覆蓋了,細心的朋友可能發現這個結果也不是我們想要的排序結果啊,我們想要的是1 2 3 4 4 5,而現在是1 3 3 4 4 5,別急不是還有a[left] = pivot;這個方法嗎,這個方法就是a[1] = 2,這樣就完美了。簡直是小母牛去南極,牛逼到了極點。
上面注釋1.4,ts.mergeCollapse(),合並排好序的分段,再說這個方法之前有必要先說一下TimSort方法定義的幾個屬性,這幾個屬性會在下面的方法中用到
class TimSort<T> { private static final int MIN_GALLOP = 7; private T[] tmp; // private int tmpBase; // base of tmp array slice private int tmpLen; // length of tmp array slice private int stackSize = 0; // Number of pending runs on stack //這個里面存放就是每段待合並的分段的在數組的開始位置 private final int[] runBase; //這里面存放的是每個分段的長度,和上面的runBase是一一對應的 private final int[] runLen; }
ok,有了上面的認識,來開始看下面的代碼
private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; //這個if判斷的意義就是后面兩段的長度之和一定要大於前面一段的長度才會執行 if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { if (runLen[n - 1] < runLen[n + 1]) n--;
// <2.1>執行合並邏輯 mergeAt(n); //如果上面的不滿足,會執行這個,這個意思就是后面一段的長度要大於前面一段的長度 } else if (runLen[n] <= runLen[n + 1]) { mergeAt(n); //否則就不進行合並 } else { break; // Invariant is established } } }
這個方法主要的作用就是為了防止太短的段和比較長的段進行合並,浪費時間,舉個例子現在runLen = [256 128 20 8],如果讓8和20合並還可以接受,但是讓20和128合並就太浪費時間了,至於說小段和大段合並為什么比較浪費,后面會分析,不過這里要做一個補充,其實在這里即便不做合並,后面還是會合並的,只是后面那個合並是一次性的,就是一個方法把所有的未合並分段全部合並了。
注釋<2.1>,下面進入mergeAt(n);
private void mergeAt(int i) { assert stackSize >= 2; assert i >= 0; assert i == stackSize - 2 || i == stackSize - 3; int base1 = runBase[i]; int len1 = runLen[i]; int base2 = runBase[i + 1]; int len2 = runLen[i + 1]; assert len1 > 0 && len2 > 0; assert base1 + len1 == base2; /* * Record the length of the combined runs; if i is the 3rd-last * run now, also slide over the last run (which isn't involved * in this merge). The current run (i+1) goes away in any case. */ runLen[i] = len1 + len2; if (i == stackSize - 3) { runBase[i + 1] = runBase[i + 2]; runLen[i + 1] = runLen[i + 2]; } stackSize--; /* * Find where the first element of run2 goes in run1. Prior elements * in run1 can be ignored (because they're already in place). */ // int k = gallopRight(a[base2], a, base1, len1, 0, c); assert k >= 0; base1 += k; len1 -= k; if (len1 == 0) return; /* * Find where the last element of run1 goes in run2. Subsequent elements * in run2 can be ignored (because they're already in place). */ // len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c); assert len2 >= 0; if (len2 == 0) return; // Merge remaining runs, using tmp array with min(len1, len2) elements if (len1 <= len2) mergeLo(base1, len1, base2, len2); else
//<3.1> 合並
mergeHi(base1, len1, base2, len2); }
這個方法其實也沒有執行合並的邏輯,那這個方法在干啥呢?其實這個方法還是在縮短比較的段的長度,其中兩個主要的方法就是gallopRigth()和gallopLeft(),這個兩個方法是在干啥呢?我下面還是舉例說明吧。
假設:
第一個段為:[1,2,3,5,6,8,9]
第二個段為:[4,6,7,8,10,11,12]
第一個段在數組中的位置在第二個段前面,這里注意實際的段不可能這么短,上面有說段的長度應該在16到32之間,這里只是舉例為了說明問題。
gallopRigth(): 尋找第二段的第一個元素在第一段中的位置,比如例子中位置為2,那也就是說第一段的前兩個元素沒有必要參與合並,他的位置不用動。
gallopLeft(): 尋找第一段的結尾元素在第二段中的位置,這里發現在第二段的第4,那也就是說第二段的10,11,12沒有必要參與合並,同樣是位置不需要改動。
最終參與合並的段為:
第一段:[5,6,8,9]
第二段:[4, 6, 7, 8]
這樣參與合並的段的長度就大大減小,時間相應的就變短了,可能細心的小伙伴到這里就有一個疑問了,gallopRigth()是尋找第二段的第一個元素在第一段中的位置,而不是反過來,我覺得能想到這個疑問的朋友應該稍微思考一下就知道這個原因了,我就不多嘴了。所以gallopRigth()和gallopLeft()的源碼我就不分析了,有興趣的可以自己去看,里面還有些細節沒有寫到。
注釋<3.1>,mergeHi(base1, len1, base2, len2);這個方法和注釋<3.1>,mergeLo(base1, len1, base2, len2);是類似的方法,分析其中一個就可以了
//解釋一下參數: base1 = 第一段的開始位置,len1 = 第一段的長度,base2 = 第二段的開始位置,len2 = 第二段的長度
private void mergeHi(int base1, int len1, int base2, int len2) { assert len1 > 0 && len2 > 0 && base1 + len1 == base2; // Copy second run into temp array T[] a = this.a; // For performance
//這里建立了一個空數組,目的是為了存放第二段的數據 T[] tmp = ensureCapacity(len2);
//這里的temBase是TimSort定義的一個屬性,在TimSort初始化的時候給了一個初始化值0 int tmpBase = this.tmpBase;
//這里就是給上面新建的空數組賦值的,就是第二段放入到這個臨時數組中 System.arraycopy(a, base2, tmp, tmpBase, len2); //下面定義兩個游標,控制每一段比較的位置 int cursor1 = base1 + len1 - 1; // Indexes into a int cursor2 = tmpBase + len2 - 1; // Indexes into tmp array
//這個就是第二段的結束位置,這兩段進行比較的時候也是從末尾開始比較,這里就是記錄兩段中比較大的元素會放入到這個位置上 int dest = base2 + len2 - 1; // Indexes into a // Move last element of first run and deal with degenerate cases
//從這一句就可以看出實現這個的作者的細致,解釋一下這句,cursor1是第一段的結束位置,dest是第二段的結束位置,第一段的結束位置的值一定大於第二段的結束位置的值
//至於原因就是上面我分析的gallopLeft()方法和gollopRight() a[dest--] = a[cursor1--];
//這一句的意思就是如果len1 = 1,說明第二段都應該放入到第一段這個值的前面,原因還是上面的那個原因 if (--len1 == 0) { System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2); return; }
//這一句的意思就是如果第二段的長度為1,那就把他放入到第一段的前面 if (len2 == 1) { dest -= len1; cursor1 -= len1; System.arraycopy(a, cursor1 + 1, a, dest + 1, len1); a[dest] = tmp[cursor2]; return; } Comparator<? super T> c = this.c; // Use local variable for performance int minGallop = this.minGallop; // 這個minGallop = 7,是默認值,為啥搞這個值,看下面就知道了 outer: while (true) {
int count1 = 0; // Number of times in a row that first run won,其實這個count1和count2非常有意思,就是記錄第一段中連續比第二段大的數的個數,注意是連續 int count2 = 0; // Number of times in a row that second run won,這個就是記錄第二段中連續比第一段大的數字的個數 /* * Do the straightforward thing until (if ever) one run * appears to win consistently.
* 下面有兩個do...while循環,其實是可以使用一個do...while循環實現的,但是作者為了優化,搞了兩個do...while循環,下面我先說一下這個循環干了啥,為什么要搞兩個
* 使用do...while循環的作用就是分別從第一段的最后一個數和第二段的最后一個數做比較,比較大小之后,誰比較大就放在dist的位置,這個位置其實就是從第二段結尾的位置逐漸減小
* 接下來說一下為什么要搞兩個do...while循環,如果在這個比較當中發現count1 > 7或者count2 > 7,說明什么,說明第一段中有連續7個值大於第二段的未被比較的最大值,那就說明可能存在更多的值
* 大於第二段中的未被比較的值中的最大值,所以呢,他又調用了gallopLeft()和gallopRigth()把不需要比較的找出來進一步縮短合並的段的大小
* 看了上看的分析,是不是覺得作者非常牛逼,簡直是小母牛掉進酒缸,醉牛逼 */ do { assert len1 > 0 && len2 > 1;
//比較第一段的最后一個元素和第二段的最后一個元素的大小 if (c.compare(tmp[cursor2], a[cursor1]) < 0) { a[dest--] = a[cursor1--]; count1++; count2 = 0; if (--len1 == 0) break outer; } else { a[dest--] = tmp[cursor2--]; count2++; count1 = 0; if (--len2 == 1) break outer; }
//如果count1 >= 7 或者count2 = 7,跳出循環,進入下一個循環中,把不需要合並的剔除 } while ((count1 | count2) < minGallop); /* * One run is winning so consistently that galloping may be a * huge win. So try that, and continue galloping until (if ever) * neither run appears to be winning consistently anymore.
* 這個do...while循環有點神奇,就是當count1>=7或者count2 >=7的時候就在這個里面實現合並,而不重新跳回第一個do...while循環
*
* */ do { assert len1 > 0 && len2 > 1;
//是不是又看到這個熟悉的方法 count1 = len1 - gallopRight(tmp[cursor2], a, base1, len1, len1 - 1, c); if (count1 != 0) { dest -= count1; cursor1 -= count1; len1 -= count1; System.arraycopy(a, cursor1 + 1, a, dest + 1, count1); if (len1 == 0) break outer; }
//這個do...while循環的合並就是采用這一句,每次這個合並完之后,重新去執行gallopRight或者gallopLeft方法,重新把不用合並的剔除掉 a[dest--] = tmp[cursor2--]; if (--len2 == 1) break outer; count2 = len2 - gallopLeft(a[cursor1], tmp, tmpBase, len2, len2 - 1, c); if (count2 != 0) { dest -= count2; cursor2 -= count2; len2 -= count2; System.arraycopy(tmp, cursor2 + 1, a, dest + 1, count2); if (len2 <= 1) // len2 == 1 || len2 == 0 break outer; }
//這個和上面類似,就是使用這個進行合並 a[dest--] = a[cursor1--]; if (--len1 == 0) break outer;
//這里的minGallop初始值是7,在這循環中,每循環一次就減1 minGallop--;
//如果發現第一段或者第二段的長度小於7了,就跳出這個循環 } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); if (minGallop < 0) minGallop = 0;
//重新給minGallop一個新值,跳回到第一個do...while中 minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field if (len2 == 1) { assert len1 > 0; dest -= len1; cursor1 -= len1; System.arraycopy(a, cursor1 + 1, a, dest + 1, len1); a[dest] = tmp[cursor2]; // Move first elt of run2 to front of merge } else if (len2 == 0) { throw new IllegalArgumentException( "Comparison method violates its general contract!"); } else { assert len1 == 0; assert len2 > 0; System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2); } }
總結:其實合並的過程就是在這個兩個do...while之間來回跳的過程,而第二個do...while循環其實是對合並的一個優化,即便沒有第二個循環也可以完成合並操作,不過要修改一下第一個循環的條件,而第二個循環是怎么優化的呢?這里就是作者的一個重要的思考了,就是當第二段的值連續大於第一段的某個值7次,是不是可以認為第二段中有可能有更多的值大於第一段呢?我覺得這個推斷完全是正確的,做了這個優化之后就可以減少很多需要合並的值,這就是作者的厲害之處。
分析完以上合並過程,其實並沒有完,為什么?因為上面1.4中說,並不是隨便兩個相鄰的段都可以合並,而要滿主一定的條件才可以合並,滿足什么條件呢?其實上面已經說了,這里在重復一遍。
假設:連續的三段的長度x,y,z只要滿足如下條件就合並:
x <= y + z. || y <=z
那相反的就是:
x > y+z. && y >z
也就是說滿足上面條件的段就不會使用上面的方法進行合並,那這些個沒有合並的段在哪里合並的呢,在下面的代碼中合並。
注釋<1.5>,ts.mergeForceCollapse();執行最終的合並
private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; mergeAt(n); } }
這個里面stackSize就是棧的深度,其實就是沒有合並的的段的多少,如果stackSize = 1說明什么,說明就只有一段了,那就說明已經合並完成了,至於mergeAt(n),這個方法我在上面已經介紹過了。
最后總結
這個排序算法其實還是很有必要看一下的,寫的很有意思,里面做了大量的優化,從這些優化中我們可以學習到很多的東西,學到什么東西呢?可以看到這些大佬是怎么思考的,是怎么做事情的,在看這個代碼的過程中我發現實現這個代碼的作者真是非常的細致,每個能優化的點都考慮的非常清楚,簡直是小母牛做鋼鋸,巨牛逼。另外在參考文章我也推薦大家好好看看,這篇文章的作者也很有意思,里面基本把上面的代碼的過程給寫出來了,只是沒有把代碼貼出來,感謝前輩。
參考文章: