Lua table.sort()原理和使用的坑


參考博客:lahmiley

最近使用table.sort()的時候遇到了一個報錯的問題:invalid order function for sorting。
感覺很奇怪,於是總結下方法的原理和報錯的原因。

先討論下lua里面sort的實現:

table.sort原理和內部實現

  • table.sort的內部使用的是快排,並對其做了三點優化。

  • 刷題的時候可能我們寫的快排大部分會直接使用數組開頭作為基點,但是這樣的話,當我們遇到數組已經是排好序的情況的時候,快排會退化為冒泡,時間復雜度升到了On^2,會很耗費性能。所以這里對其進行了優化,使用數組的開頭、中間、結尾中間大的元素作為基點,減少特殊情況時快排的次數。

  • 對於迭代函數中的分割數組長度小於等於3的,直接通過比較大小交換位置,形成有序數組,這樣做的目的是減少遞歸調用的深度。

  • 每次通過錨點把分割數組分成兩半之后,對長度較小的一半進行遞歸調用,另一半則繼續通過While繼續分割處理,目的應該也是減少遞歸調用的深度

table.sort源碼

    static void auxsort (lua_State *L, int l, int u) {
        while (l < u) {  /* for tail recursion */
            int i, j;
            /* sort elements a[l], a[(l+u)/2] and a[u] */
            lua_rawgeti(L, 1, l);
            lua_rawgeti(L, 1, u);
            
            if (sort_comp(L, -1, -2))  /* a[u] < a[l]? */
                set2(L, l, u);  /* swap a[l] - a[u] */
            else
                lua_pop(L, 2);
            
            if (u-l == 1) 
                break;  /* only 2 elements */

            i = (l+u)/2;
            lua_rawgeti(L, 1, i);
            lua_rawgeti(L, 1, l);

            if (sort_comp(L, -2, -1))  /* a[i]<a[l]? */
                set2(L, i, l);
            else {
                lua_pop(L, 1);  /* remove a[l] */
                lua_rawgeti(L, 1, u);
                if (sort_comp(L, -1, -2))  /* a[u]<a[i]? */
                    set2(L, i, u);
                else
                    lua_pop(L, 2);
            }

            if (u-l == 2) 
                break;  /* only 3 elements */

            lua_rawgeti(L, 1, i);  /* Pivot */
            lua_pushvalue(L, -1);
            lua_rawgeti(L, 1, u-1);
            set2(L, i, u-1);
        
            //上面代碼是對分割數組長度小於等於3的進行比較和排序
            //並從數組初始點、中間點、結尾點中選擇中位數作為錨點
        
            /* a[l] <= P == a[u-1] <= a[u], only need to sort from l+1 to u-2 */
            i = l; j = u-1;
            for (;;) {  /* invariant: a[l..i] <= P <= a[j..u] */
                /* repeat ++i until a[i] >= P */
                while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
                    if (i>u) 
                        luaL_error(L, "invalid order function for sorting");
                    lua_pop(L, 1);  /* remove a[i] */
                }
                /* repeat --j until a[j] <= P */
                while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
                    if (j<l) 
                        luaL_error(L, "invalid order function for sorting");
                    lua_pop(L, 1);  /* remove a[j] */
                }
                if (j<i) {
                    lua_pop(L, 3);  /* pop pivot, a[i], a[j] */
                    break;
                }
                set2(L, i, j);
            }
            lua_rawgeti(L, 1, u-1);
            lua_rawgeti(L, 1, i);
            set2(L, u-1, i);  /* swap pivot (a[u-1]) with a[i] */
        
            //上面代碼是快排算法,依據選擇的錨點把小於錨點的放在一邊,大於錨點的放在另一邊
        
            /* a[l..i-1] <= a[i] == P <= a[i+1..u] */
            /* adjust so that smaller half is in [j..i] and larger one in [l..u] */
            if (i-l < u-i) {
                j=l; 
                i=i-1; 
                l=i+2;
            }
            else {
                j=i+1; 
                i=u; 
                u=j-2;
            }
            auxsort(L, j, i);  /* call recursively the smaller one */
        }  /* repeat the routine for the larger one */
        
        //上面代碼是讓分割后長度較短的數組繼續迭代,長度較長的則繼續通過while進行快排算法,減少遞歸調用的次數
    }

報錯的位置源碼:

    while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
        if (i>u) 
            luaL_error(L, "invalid order function for sorting");
        lua_pop(L, 1);  /* remove a[i] */
    }
    /* repeat --j until a[j] <= P */
    while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
        if (j<l) 
            luaL_error(L, "invalid order function for sorting");
        lua_pop(L, 1);  /* remove a[j] */
    }

報錯原因和解決

報錯出現的條件:當基點和數組邊界值有相等時,這時如果排序方法sort_comp返回true,則會造成數組越界。

local array = {9,15,9,222,10}

--此時的基點是9,和開頭元素相等
table.sort(array, function(a, b)
    --遞增
    return a <= b
end)
--報錯:invalid order function for sorting

報錯解決:當比較的兩個值相等時,返回false即可。

local array = {9,15,9,222,10}

table.sort(array, function(a, b)
    return a < b
end)

補充(當數組中有為nil的元素時)

sort排序的表必須是從1-n連續的,不能有nil。
不然的話,排序會把nil的前一個元素當做尾元素來進行排序。

local array = {9,15,9,nil,10}

table.sort(array, function(a, b)
    return a < b
end)

for i = 1, #array do
    print(array[i])
end
--[[ 輸出:
    9 9 15
]]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM