
參考博客:lahmiley
最近使用table.sort()的時候遇到了一個報錯的問題:invalid order function for sorting。
感覺很奇怪,於是總結下方法的原理和報錯的原因。
先討論下lua里面sort的實現:
table.sort原理和內部實現
-
table.sort的內部使用的是快排,並對其做了三點優化。
-
刷題的時候可能我們寫的快排大部分會直接使用數組開頭作為基點,但是這樣的話,當我們遇到數組已經是排好序的情況的時候,快排會退化為冒泡,時間復雜度升到了On^2,會很耗費性能。所以這里對其進行了優化,使用數組的開頭、中間、結尾中間大的元素作為基點,減少特殊情況時快排的次數。
-
對於迭代函數中的分割數組長度小於等於3的,直接通過比較大小交換位置,形成有序數組,這樣做的目的是減少遞歸調用的深度。
-
每次通過錨點把分割數組分成兩半之后,對長度較小的一半進行遞歸調用,另一半則繼續通過While繼續分割處理,目的應該也是減少遞歸調用的深度
table.sort源碼
static void auxsort (lua_State *L, int l, int u) {
while (l < u) { /* for tail recursion */
int i, j;
/* sort elements a[l], a[(l+u)/2] and a[u] */
lua_rawgeti(L, 1, l);
lua_rawgeti(L, 1, u);
if (sort_comp(L, -1, -2)) /* a[u] < a[l]? */
set2(L, l, u); /* swap a[l] - a[u] */
else
lua_pop(L, 2);
if (u-l == 1)
break; /* only 2 elements */
i = (l+u)/2;
lua_rawgeti(L, 1, i);
lua_rawgeti(L, 1, l);
if (sort_comp(L, -2, -1)) /* a[i]<a[l]? */
set2(L, i, l);
else {
lua_pop(L, 1); /* remove a[l] */
lua_rawgeti(L, 1, u);
if (sort_comp(L, -1, -2)) /* a[u]<a[i]? */
set2(L, i, u);
else
lua_pop(L, 2);
}
if (u-l == 2)
break; /* only 3 elements */
lua_rawgeti(L, 1, i); /* Pivot */
lua_pushvalue(L, -1);
lua_rawgeti(L, 1, u-1);
set2(L, i, u-1);
//上面代碼是對分割數組長度小於等於3的進行比較和排序
//並從數組初始點、中間點、結尾點中選擇中位數作為錨點
/* a[l] <= P == a[u-1] <= a[u], only need to sort from l+1 to u-2 */
i = l; j = u-1;
for (;;) { /* invariant: a[l..i] <= P <= a[j..u] */
/* repeat ++i until a[i] >= P */
while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
if (i>u)
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[i] */
}
/* repeat --j until a[j] <= P */
while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
if (j<l)
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[j] */
}
if (j<i) {
lua_pop(L, 3); /* pop pivot, a[i], a[j] */
break;
}
set2(L, i, j);
}
lua_rawgeti(L, 1, u-1);
lua_rawgeti(L, 1, i);
set2(L, u-1, i); /* swap pivot (a[u-1]) with a[i] */
//上面代碼是快排算法,依據選擇的錨點把小於錨點的放在一邊,大於錨點的放在另一邊
/* a[l..i-1] <= a[i] == P <= a[i+1..u] */
/* adjust so that smaller half is in [j..i] and larger one in [l..u] */
if (i-l < u-i) {
j=l;
i=i-1;
l=i+2;
}
else {
j=i+1;
i=u;
u=j-2;
}
auxsort(L, j, i); /* call recursively the smaller one */
} /* repeat the routine for the larger one */
//上面代碼是讓分割后長度較短的數組繼續迭代,長度較長的則繼續通過while進行快排算法,減少遞歸調用的次數
}
報錯的位置源碼:
while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
if (i>u)
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[i] */
}
/* repeat --j until a[j] <= P */
while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
if (j<l)
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[j] */
}
報錯原因和解決
報錯出現的條件:當基點和數組邊界值有相等時,這時如果排序方法sort_comp返回true,則會造成數組越界。
local array = {9,15,9,222,10}
--此時的基點是9,和開頭元素相等
table.sort(array, function(a, b)
--遞增
return a <= b
end)
--報錯:invalid order function for sorting
報錯解決:當比較的兩個值相等時,返回false即可。
local array = {9,15,9,222,10}
table.sort(array, function(a, b)
return a < b
end)
補充(當數組中有為nil的元素時)
sort排序的表必須是從1-n連續的,不能有nil。
不然的話,排序會把nil的前一個元素當做尾元素來進行排序。
local array = {9,15,9,nil,10}
table.sort(array, function(a, b)
return a < b
end)
for i = 1, #array do
print(array[i])
end
--[[ 輸出:
9 9 15
]]