Lua table.sort

Lua的table.sort

首先先上代码 ltablib.c

static int sort (lua_State *L) {
  // table的大小为n
  lua_Integer n = aux_getn(L, 1, TAB_RW);
  if (n > 1) {  /* non-trivial interval? */
    luaL_argcheck(L, n < INT_MAX, 1, "array too big");
    if (!lua_isnoneornil(L, 2))  /* is there a 2nd argument? */
      luaL_checktype(L, 2, LUA_TFUNCTION);  /* must be a function */
    lua_settop(L, 2);  /* make sure there are two arguments */
    // sort的排序算法
    auxsort(L, 1, (IdxT)n, 0);
  }
  return 0;
}
// lo 当前排序的table的左边界 up为右边界 
static void auxsort (lua_State *L, IdxT lo, IdxT up, unsigned int rnd) {
  while (lo < up) {  /* loop for tail recursion */
    IdxT p;  /* Pivot index */
    IdxT n;  /* to be used later */
	// 第一部分 保证t[lo] < t[up] 此时如果是 up - lo == 1 即只有两个元素就代表排序完成了
    /* sort elements 'lo', 'p', and 'up' */
    lua_geti(L, 1, lo);
    lua_geti(L, 1, up);
    if (sort_comp(L, -1, -2))  /* a[up] < a[lo]? */
      set2(L, lo, up);  /* swap a[lo] - a[up] */
    else
      lua_pop(L, 2);  /* remove both values */
    if (up - lo == 1)  /* only 2 elements? */
      return;  /* already sorted */

	// 第二部分 此时t[lo] < t[up] 
	// p = (lo + up)/2 保证t[lo] < t[p] < t[up] 此时如果是 up - lo == 2 即只有三个元素就代表排序完成了
    if (up - lo < RANLIMIT || rnd == 0)  /* small interval or no randomize? */
      p = (lo + up)/2;  /* middle element is a good pivot */
    else  /* for larger intervals, it is worth a random pivot */
      p = choosePivot(lo, up, rnd);
    lua_geti(L, 1, p);
    lua_geti(L, 1, lo);
    // sort_comp是我们提供的排序函数
    if (sort_comp(L, -2, -1))  /* a[p] < a[lo]? */
      set2(L, p, lo);  /* swap a[p] - a[lo] */
    else {
      lua_pop(L, 1);  /* remove a[lo] */
      lua_geti(L, 1, up);
      if (sort_comp(L, -1, -2))  /* a[up] < a[p]? */
        set2(L, p, up);  /* swap a[up] - a[p] */
      else
        lua_pop(L, 2);
    }
    if (up - lo == 2)  /* only 3 elements? */
      return;  /* already sorted */

	// 记录此时t[p]的值,然后将t[p]和t[up - 1]值互换 执行partition
	// 然后递归lo到p - 1部分或 p + 1到up部分
    lua_geti(L, 1, p);  /* get middle element (Pivot) */
    lua_pushvalue(L, -1);  /* push Pivot */
    lua_geti(L, 1, up - 1);  /* push a[up - 1] */
    set2(L, p, up - 1);  /* swap Pivot (a[p]) with a[up - 1] */
    p = partition(L, lo, up);
    /* a[lo .. p - 1] <= a[p] == P <= a[p + 1 .. up] */
    if (p - lo < up - p) {  /* lower interval is smaller? */
      auxsort(L, lo, p - 1, rnd);  /* call recursively for lower interval */
      n = p - lo;  /* size of smaller interval */
      lo = p + 1;  /* tail call for [p + 1 .. up] (upper interval) */
    }
    else {
      auxsort(L, p + 1, up, rnd);  /* call recursively for upper interval */
      n = up - p;  /* size of smaller interval */
      up = p - 1;  /* tail call for [lo .. p - 1]  (lower interval) */
    }
    if ((up - lo) / 128 > n) /* partition too imbalanced? */
      rnd = l_randomizePivot();  /* try a new randomization */
  }  /* tail call auxsort(L, lo, up, rnd) */
}

static IdxT partition (lua_State *L, IdxT lo, IdxT up) {
  IdxT i = lo;  /* will be incremented before first use */
  IdxT j = up - 1;  /* will be decremented before first use */
  /* loop invariant: a[lo .. i] <= P <= a[j .. up] */
  for (;;) {
    /* next loop: repeat ++i while a[i] < P */
    while ((void)lua_geti(L, 1, ++i), sort_comp(L, -1, -2)) {
      if (l_unlikely(i == up - 1))  /* a[i] < P  but a[up - 1] == P  ?? */
        luaL_error(L, "invalid order function for sorting");
      lua_pop(L, 1);  /* remove a[i] */
    }
    /* after the loop, a[i] >= P and a[lo .. i - 1] < P */
    /* next loop: repeat --j while P < a[j] */
    while ((void)lua_geti(L, 1, --j), sort_comp(L, -3, -1)) {
      if (l_unlikely(j < i))  /* j < i  but  a[j] > P ?? */
        luaL_error(L, "invalid order function for sorting");
      lua_pop(L, 1);  /* remove a[j] */
    }
    /* after the loop, a[j] <= P and a[j + 1 .. up] >= P */
    if (j < i) {  /* no elements out of place? */
      /* a[lo .. i - 1] <= P <= a[j + 1 .. i .. up] */
      lua_pop(L, 1);  /* pop a[j] */
      /* swap pivot (a[up - 1]) with a[i] to satisfy pos-condition */
      set2(L, up - 1, i);
      return i;
    }
    /* otherwise, swap a[i] - a[j] to restore invariant and repeat */
    set2(L, i, j);
  }
}

可以清晰的看到lua的排序是使用的快排思想。
在table较小时候(小于RANLIMIT lua5.4.4 为100)此时找到的标记位为P = t[(lo+up)/2],通过partition函数将table处理成小于P和大于P的两部分,并返回此时的P的位置,这种情况就是正常的快排。
在table比较大的时候,此时的标记位会是choosePivot,是随机出来的一个位置。

static IdxT choosePivot (IdxT lo, IdxT up, unsigned int rnd) {
  IdxT r4 = (up - lo) / 4;  /* range/4 */
  IdxT p = rnd % (r4 * 2) + (lo + r4);
  lua_assert(lo + r4 <= p && p <= up - r4);
  return p;
}

这个处理主要是因为快排的时间复杂度最优的情况下是O(NlogN),最差的情况是O(N^2),使用随机函数可以在大样本范围保证长期期望的时间复杂度是O(NlogN)

用lua实现一遍

用lua实现了一下,可以直接执行,如果有问题欢迎指正

local rawTableSort = table.sort
local MAX_INT = 2147483647
local RANLIMIT = 100

function printt(t)
    print(table.concat(t, ", "))
end

function swap(t, lo, up)
    local tmp = t[lo]
    t[lo] = t[up]
    t[up] = tmp
end

function partition(t, lo, up, p)
    local i = lo + 1
    local j = up - 2
    while true do
        while t[i] and t[i] < p do
            i = i + 1
        end
        while t[j] and t[j] > p do
            j = j - 1
        end
        if j < i then
            swap(t, up - 1, i)
            return i
        end
        swap(t, j, i)
    end
end

function choosePivot(lo, up, rnd)
    local r4 = (up - lo) / 4
    local p = rnd % (r4 * 2) + (lo + r4)
    return p
end

table.sort = function(t, sortFunc)
    local n = #t
    if n > 1 then
        if n > MAX_INT then
            print("array too big")
            return
        end
        if sortFunc ~= nil then
            if type(sortFunc) ~= "function" then
                print("bad argument #2 to 'sort'")
                return
            end
        end
        local auxsort
        auxsort = function(t, lo, up, rnd)
            while lo < up do
                local p
                -- local n
                if sortFunc(t[up], t[lo]) then
                    swap(t, up, lo)
                end
                if up - lo == 1 then
                    return
                end
                if up - lo < RANLIMIT or rnd == 0 then
                    p = math.floor((lo + up) / 2)
                else
                    p = choosePivot(lo, up, rnd)
                end
                if sortFunc(t[p], t[lo]) then
                    swap(t, p, lo)
                else
                    if sortFunc(t[up], t[p]) then
                        swap(t, p, up)
                    end
                end
                if up - lo == 2 then
                    return
                end
                swap(t, p, up - 1)
                p = partition(t, lo, up, t[up - 1])
                if p - lo < up - p then
                    auxsort(t, lo, p - 1, rnd)
                    -- n = p - lo
                    lo = p + 1
                else
                    auxsort(t, p + 1, up, rnd)
                    -- n = up - p
                    up = p - 1
                end
                -- if (up - lo) / 128 > n then
                --     rnd = l_randomizePivot()
                -- end
            end
        end
        auxsort(t, 1, n, 0)
    end
end


local tab = { 3, 4, 16, 2, 6, 0, 9, 1, 7}

table.sort(tab, function(a, b)
    return a < b
end)

printt(tab)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值