Lua的table.sort
首先上代码(摘自 Lua 5.4 源码 ltablib.c):
/*
** sort(t [, comp]): C entry point of table.sort.
** Reads the length n of the table at stack index 1 (aux_getn with TAB_RW),
** checks that n fits below INT_MAX and that the optional argument 2 is a
** function, then sorts the range [1, n] in place via auxsort.
** Returns no Lua values.
*/
static int sort (lua_State *L) {
// n is the length of the table passed as argument 1
lua_Integer n = aux_getn(L, 1, TAB_RW);
if (n > 1) { /* non-trivial interval? */
luaL_argcheck(L, n < INT_MAX, 1, "array too big");
if (!lua_isnoneornil(L, 2)) /* is there a 2nd argument? */
luaL_checktype(L, 2, LUA_TFUNCTION); /* must be a function */
lua_settop(L, 2); /* make sure there are two arguments */
// sort the whole range [1, n]; rnd = 0 means auxsort starts with middle pivots
auxsort(L, 1, (IdxT)n, 0);
}
return 0;
}
/*
** auxsort: sorts a[lo .. up] of the table at stack index 1 in place.
** lo is the left bound and up the right bound of the interval being sorted.
** rnd == 0 selects the middle element as pivot; a non-zero rnd is handed to
** choosePivot for intervals with up - lo >= RANLIMIT. The smaller side of
** each partition is sorted by a recursive call and the larger side by the
** enclosing loop, so the recursion depth stays logarithmic in (up - lo).
*/
static void auxsort (lua_State *L, IdxT lo, IdxT up, unsigned int rnd) {
while (lo < up) { /* loop for tail recursion */
IdxT p; /* Pivot index */
IdxT n; /* to be used later */
// Step 1: order a[lo] and a[up] so that afterwards !(a[up] < a[lo]).
// If up - lo == 1 there are only two elements and the interval is sorted.
/* sort elements 'lo', 'p', and 'up' */
lua_geti(L, 1, lo);
lua_geti(L, 1, up);
if (sort_comp(L, -1, -2)) /* a[up] < a[lo]? */
set2(L, lo, up); /* swap a[lo] - a[up] */
else
lua_pop(L, 2); /* remove both values */
if (up - lo == 1) /* only 2 elements? */
return; /* already sorted */
// Step 2: now a[lo] <= a[up]. Choose the pivot index p (the middle, or
// choosePivot for large intervals once rnd != 0) and order the three
// elements so that a[lo] <= a[p] <= a[up]; with only three elements
// (up - lo == 2) the interval is then sorted.
if (up - lo < RANLIMIT || rnd == 0) /* small interval or no randomize? */
p = (lo + up)/2; /* middle element is a good pivot */
else /* for larger intervals, it is worth a random pivot */
p = choosePivot(lo, up, rnd);
lua_geti(L, 1, p);
lua_geti(L, 1, lo);
// sort_comp applies the comparator given to table.sort
// (presumably plain '<' when none was given -- see sort_comp)
if (sort_comp(L, -2, -1)) /* a[p] < a[lo]? */
set2(L, p, lo); /* swap a[p] - a[lo] */
else {
lua_pop(L, 1); /* remove a[lo] */
lua_geti(L, 1, up);
if (sort_comp(L, -1, -2)) /* a[up] < a[p]? */
set2(L, p, up); /* swap a[up] - a[p] */
else
lua_pop(L, 2);
}
if (up - lo == 2) /* only 3 elements? */
return; /* already sorted */
// Step 3: push the pivot value P (presumably one copy stays on the stack
// for partition), swap a[p] with a[up - 1], and partition. Then recurse
// into the smaller of [lo, p - 1] / [p + 1, up] and loop on the other.
lua_geti(L, 1, p); /* get middle element (Pivot) */
lua_pushvalue(L, -1); /* push Pivot */
lua_geti(L, 1, up - 1); /* push a[up - 1] */
set2(L, p, up - 1); /* swap Pivot (a[p]) with a[up - 1] */
p = partition(L, lo, up);
/* a[lo .. p - 1] <= a[p] == P <= a[p + 1 .. up] */
if (p - lo < up - p) { /* lower interval is smaller? */
auxsort(L, lo, p - 1, rnd); /* call recursively for lower interval */
n = p - lo; /* size of smaller interval */
lo = p + 1; /* tail call for [p + 1 .. up] (upper interval) */
}
else {
auxsort(L, p + 1, up, rnd); /* call recursively for upper interval */
n = up - p; /* size of smaller interval */
up = p - 1; /* tail call for [lo .. p - 1] (lower interval) */
}
if ((up - lo) / 128 > n) /* partition too imbalanced? */
rnd = l_randomizePivot(); /* try a new randomization */
} /* tail call auxsort(L, lo, up, rnd) */
}
/*
** partition: two-index partition of a[lo .. up] around the pivot value P.
** Expects P on top of the Lua stack (it is compared as index -2 once a[i]
** or a[j] has been pushed), and, as arranged by auxsort, a[lo] <= P,
** P <= a[up], and P itself stored at a[up - 1]. Those bounds stop the two
** scanning loops for any consistent comparator.
** Returns the final pivot index i, with a[lo .. i-1] <= P == a[i] <= a[i+1 .. up].
** Raises "invalid order function for sorting" when a scan runs past its
** sentinel, which can only happen with an inconsistent comparator.
*/
static IdxT partition (lua_State *L, IdxT lo, IdxT up) {
IdxT i = lo; /* will be incremented before first use */
IdxT j = up - 1; /* will be decremented before first use */
/* loop invariant: a[lo .. i] <= P <= a[j .. up] */
for (;;) {
/* next loop: repeat ++i while a[i] < P */
while ((void)lua_geti(L, 1, ++i), sort_comp(L, -1, -2)) {
if (l_unlikely(i == up - 1)) /* a[i] < P but a[up - 1] == P ?? */
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[i] */
}
/* after the loop, a[i] >= P and a[lo .. i - 1] < P */
/* next loop: repeat --j while P < a[j] */
while ((void)lua_geti(L, 1, --j), sort_comp(L, -3, -1)) {
if (l_unlikely(j < i)) /* j < i but a[j] > P ?? */
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[j] */
}
/* after the loop, a[j] <= P and a[j + 1 .. up] >= P */
if (j < i) { /* no elements out of place? */
/* a[lo .. i - 1] <= P <= a[j + 1 .. i .. up] */
lua_pop(L, 1); /* pop a[j] */
/* swap pivot (a[up - 1]) with a[i] to satisfy pos-condition */
set2(L, up - 1, i);
return i;
}
/* otherwise, swap a[i] - a[j] to restore invariant and repeat */
set2(L, i, j);
}
}
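可以清楚地看到,Lua 的排序采用的是快速排序的思想。
当区间较小(up - lo < RANLIMIT,Lua 5.4.4 中 RANLIMIT 为 100)或尚未启用随机化(rnd == 0)时,基准(pivot)位置取中点 p = (lo+up)/2。此时先把 t[lo]、t[p]、t[up] 三者排好序(三数取中),排序后的 t[p] 就是基准值 P。接着通过 partition 函数把区间划分为不大于 P 和不小于 P 的两部分,并返回 P 的最终位置,这就是普通的快速排序。
注意 sort 最初以 rnd = 0 调用 auxsort,所以即使 table 很大,一开始用的也是中点基准。只有当某次划分严重失衡时,才会调用 l_randomizePivot 得到一个随机数 rnd。判断条件是 (up - lo)/128 > n,其中 n 为较小一侧区间的长度,up - lo 为较大一侧。此后对 up - lo ≥ RANLIMIT 的区间改用 choosePivot,在 [lo + (up-lo)/4, up - (up-lo)/4] 范围内随机选取基准位置。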
/*
** choosePivot: pick a pivot index from the middle half of [lo, up], i.e.
** from [lo + (up - lo)/4, up - (up - lo)/4], using rnd as the source of
** randomness. auxsort only calls it when up - lo >= RANLIMIT (100 in
** Lua 5.4.4), so the quarter below is never zero.
*/
static IdxT choosePivot (IdxT lo, IdxT up, unsigned int rnd) {
  IdxT quarter = (up - lo) / 4;        /* range/4 */
  IdxT offset = rnd % (quarter * 2);   /* somewhere in [0, 2*quarter) */
  IdxT pivot = lo + quarter + offset;
  lua_assert(lo + quarter <= pivot && pivot <= up - quarter);
  return pivot;
}
这样处理是因为快速排序在最好和平均情况下的时间复杂度为 O(N log N),最坏情况下为 O(N^2)。一旦发现划分严重失衡就改用随机基准,可以让期望时间复杂度回到 O(N log N),避免特定输入持续触发最坏情况。此外,auxsort 总是递归处理较小的一侧,对较大的一侧用循环代替尾递归,因此递归深度不超过 O(log N)。
用 Lua 实现一遍
下面用 Lua 实现了一遍,可以直接执行,如有问题欢迎指正。
-- Reference to the built-in table.sort, saved before it is replaced below
-- (not otherwise used in this file).
local rawTableSort = table.sort
-- Length limit: like ltablib.c (`n < INT_MAX`), lengths >= MAX_INT are rejected.
local MAX_INT = 2147483647
-- Intervals with up - lo < RANLIMIT always use the middle pivot (100 in Lua 5.4.4).
local RANLIMIT = 100

--- Default strict ordering used when no comparator is supplied.
local function default_less(a, b)
  return a < b
end

--- Print the sequence part of `t` on one line, separated by ", ".
-- @tparam table t sequence of strings/numbers
function printt(t)
  print(table.concat(t, ", "))
end

--- Swap `t[lo]` and `t[up]` in place.
-- @tparam table t
-- @tparam number lo
-- @tparam number up
function swap(t, lo, up)
  t[lo], t[up] = t[up], t[lo]
end

--- Partition `t[lo .. up]` around the pivot value `p` (port of ltablib.c `partition`).
-- Preconditions (established by auxsort): not less(p, t[lo]),
-- not less(t[up], p), and the pivot itself is stored at `t[up - 1]`.
-- Those two bounds act as sentinels for the scanning loops.
-- @tparam table t table being partitioned
-- @tparam number lo left bound of the interval
-- @tparam number up right bound of the interval (up - lo >= 3)
-- @param p pivot value (equal to `t[up - 1]`)
-- @tparam[opt] function less strict "less than" comparator; defaults to `<`
-- @treturn number final index `i` of the pivot, with
--   t[lo .. i-1] <= p == t[i] <= t[i+1 .. up]
-- @raise "invalid order function for sorting" if `less` is inconsistent
function partition(t, lo, up, p, less)
  less = less or default_less
  local i = lo      -- incremented before first use
  local j = up - 1  -- decremented before first use
  while true do
    -- Scan right while t[i] < p; the pivot at t[up - 1] stops a valid order.
    i = i + 1
    while less(t[i], p) do
      if i == up - 1 then
        error("invalid order function for sorting", 2)
      end
      i = i + 1
    end
    -- Scan left while p < t[j]; t[lo] stops a valid order.
    j = j - 1
    while less(p, t[j]) do
      if j < i then
        error("invalid order function for sorting", 2)
      end
      j = j - 1
    end
    if j < i then
      -- Nothing left out of place: move the pivot into its final slot.
      t[up - 1], t[i] = t[i], t[up - 1]
      return i
    end
    -- t[i] >= p >= t[j]: swap them. The increment/decrement at the top of
    -- the loop guarantees progress even when both elements equal p.
    t[i], t[j] = t[j], t[i]
  end
end

--- Pick a pseudo-random pivot index in the middle half of [lo, up]
-- (port of ltablib.c `choosePivot`).
-- @tparam number lo left bound
-- @tparam number up right bound
-- @tparam number rnd non-negative integer seed
-- @treturn number integer index in [lo + r4, up - r4], r4 = floor((up - lo) / 4)
function choosePivot(lo, up, rnd)
  -- Floor so the index is an integer: in Lua 5.3+ `/` always yields a float,
  -- and t[12.5] would silently read nil.
  local r4 = math.floor((up - lo) / 4)
  if r4 < 1 then
    -- Interval too short to have a middle half; use the midpoint instead
    -- (also avoids `rnd % 0`).
    return math.floor((lo + up) / 2)
  end
  return math.floor(rnd % (r4 * 2)) + lo + r4
end

--- Lua port of ltablib.c `sort`; replaces the built-in `table.sort`.
-- Sorts the sequence `t[1 .. #t]` in place.
-- @tparam table t
-- @tparam[opt] function sortFunc strict "less than" comparator; defaults to `<`
-- @raise "array too big" if #t >= MAX_INT; "bad argument #2 to 'sort'" if
--   sortFunc is neither nil nor a function; "invalid order function for
--   sorting" if sortFunc is inconsistent.
table.sort = function(t, sortFunc)
  local n = #t
  if n <= 1 then
    return  -- trivial interval: nothing to do
  end
  if n >= MAX_INT then
    error("array too big", 2)
  end
  if sortFunc ~= nil and type(sortFunc) ~= "function" then
    error("bad argument #2 to 'sort'", 2)
  end
  local less = sortFunc or default_less

  -- Quicksort t[lo .. up]. rnd == 0 means "use the middle pivot"; a non-zero
  -- rnd seeds choosePivot for intervals with up - lo >= RANLIMIT.
  local function auxsort(lo, up, rnd)
    while lo < up do  -- the loop replaces the tail call on the larger side
      -- Order t[lo] and t[up]; two elements are then sorted.
      if less(t[up], t[lo]) then
        swap(t, up, lo)
      end
      if up - lo == 1 then
        return
      end
      local p
      if up - lo < RANLIMIT or rnd == 0 then
        p = math.floor((lo + up) / 2)
      else
        p = choosePivot(lo, up, rnd)
      end
      -- Median of three: order t[lo] <= t[p] <= t[up].
      if less(t[p], t[lo]) then
        swap(t, p, lo)
      elseif less(t[up], t[p]) then
        swap(t, p, up)
      end
      if up - lo == 2 then
        return
      end
      -- Park the pivot at up - 1, where partition expects it.
      swap(t, p, up - 1)
      p = partition(t, lo, up, t[up - 1], less)
      -- Recurse on the smaller side and loop on the larger one, which keeps
      -- the recursion depth logarithmic.
      local smaller
      if p - lo < up - p then
        auxsort(lo, p - 1, rnd)
        smaller = p - lo
        lo = p + 1
      else
        auxsort(p + 1, up, rnd)
        smaller = up - p
        up = p - 1
      end
      -- Partition too imbalanced: switch to a fresh random pivot seed, like
      -- l_randomizePivot in ltablib.c. Note: consumes the global math.random state.
      if math.floor((up - lo) / 128) > smaller then
        rnd = math.random(1, MAX_INT - 1)
      end
    end
  end

  auxsort(1, n, 0)
end
-- Demo: sort a small sample in ascending order with an explicit comparator
-- and print the result on one line.
local sample = { 3, 4, 16, 2, 6, 0, 9, 1, 7 }

local function ascending(a, b)
  return a < b
end

table.sort(sample, ascending)
printt(sample)