Monkey patcher for Lua

写了一个简单的 Monkey patcher,可以用它对已有的 module 或者 table 进行 patch,目前只支持对 function 进行 patch。代码如下:

require 'list'
require 'string_ext'
require 'table_ext'

local assert       = assert
local error        = error
local getmetatable = getmetatable
local next         = next
local pairs        = pairs
local rawget       = rawget
local rawset       = rawset
local setmetatable = setmetatable
local string       = string
local table        = table
local tostring     = tostring
local type         = type

module(...)

-- Saved original functions: original_funcs[mod][name] holds the function that
-- `name` referred to in `mod` before it was patched. Only leaf modules are
-- recorded, e.g. when patching a.b only b (keyed by the table a) is saved.
local original_funcs = {}
-- Registry of active patches: patched_modules[mod][name] is the patch function.
-- It also records pending patches whose intermediate modules were not available
-- at patch time; those are keyed by the remaining dotted path,
-- e.g. when patching a.b with foo while a does not exist yet,
-- patched_modules[_G]['a.b'] == foo.
-- All intermediate entries must be removed upon unpatching.
local patched_modules = {}
-- Metamethods found on a table before patching: original_mt[mod].__index and
-- original_mt[mod].__newindex, kept so they can be put back when unpatching.
local original_mt = {}

-- Reverse lookup from a patch function to the original function it replaces.
local patched_unpatched_map = {}

-- Assignment hook for tables that carry patches (wrapped by create_newindex).
--   * Assigning a function to a patched name does not overwrite the patch:
--     the new function is remembered as the "original" behind the patch.
--   * Assigning a table under key k applies every pending patch whose dotted
--     name starts with "k." to the new table.
--   * Everything else is stored with rawset.
-- @param t table being assigned to
-- @param k key
-- @param v value
local function __newindex(t, k, v)
    local patched = patched_modules[t]
    -- The metatable may be shared with tables that were never patched;
    -- treat those as plain assignments instead of indexing nil.
    if not patched then
        rawset(t, k, v)
        return
    end

    local patched_func = patched[k]
    if patched_func then
        -- k is patched: keep the patch visible and record v as the new original
        if type(patched_func) == 'function' and type(v) == 'function' then
            original_funcs[t] = original_funcs[t] or {}
            original_funcs[t][k] = v
            patched_unpatched_map[patched_func] = v
        else
            rawset(t, k, v)
        end
    elseif type(v) == 'table' then
        rawset(t, k, v)
        -- Only string keys can be a component of a dotted patch name.
        if type(k) == 'string' then
            -- Literal "k." prefix comparison: a pattern match on '^' .. k would
            -- also accept longer names (k = 'a' vs 'ab.c') and would treat
            -- magic characters in k as pattern syntax.
            local prefix = k .. '.'
            local prefix_len = #prefix
            -- Collect first so that patch() never mutates a table we are traversing.
            local pending = {}
            for name, func in pairs(patched) do
                if name:sub(1, prefix_len) == prefix then
                    pending[#pending + 1] = { name:sub(prefix_len + 1), func }
                end
            end
            for idx = 1, #pending do
                patch(v, pending[idx][1], pending[idx][2])
            end
        end
    else
        rawset(t, k, v)
    end
end

-- Build the __newindex metamethod installed on a patched table.
-- The returned closure first runs the patcher's own hook and then forwards
-- the assignment to the metamethod that was installed before patching.
-- @param old_newindex previous __newindex (table, function or nil)
-- @return function (t, k, v)
local function create_newindex(old_newindex)
    return function (t, k, v)
        __newindex(t, k, v)

        -- nothing was installed before us
        if not old_newindex then
            return
        end

        -- note: if the previous handler stores k in t itself,
        -- the patch can no longer intercept that key
        local kind = type(old_newindex)
        if kind == 'table' then
            old_newindex[k] = v
        elseif kind == 'function' then
            old_newindex(t, k, v)
        else
            error('unknown type of __newindex')
        end
    end
end

-- Lookup hook: return the patch function registered for t[k], otherwise the
-- raw value stored in t.
-- @param t table being indexed
-- @param k key
-- @return patch function, raw value, or nil
local function __index(t, k)
    -- The metatable may be shared with tables that were never patched,
    -- so patched_modules[t] is allowed to be nil here.
    local patched = patched_modules[t]
    local patched_func = patched and patched[k]
    if type(patched_func) == 'function' then
        return patched_func
    end
    return rawget(t, k)
end

-- Build the __index metamethod installed on a patched table.
-- The patcher's own lookup wins; when it yields nothing the request is
-- forwarded to the __index that was installed before patching.
-- @param old_index previous __index (table, function or nil)
-- @return function (t, k)
local function create_index(old_index)
    return function (t, k)
        local found = __index(t, k)
        if found then
            return found
        end

        -- no earlier handler to fall back to
        if not old_index then
            return
        end

        local kind = type(old_index)
        if kind == 'table' then
            return old_index[k]
        elseif kind == 'function' then
            return old_index(t, k)
        end
        error('unknown type of __index')
    end
end

-- Patch the function reachable from `mod` through the dotted path `name` so
-- that lookups return `func` instead.
-- If part of the path does not exist yet, the patch is registered as pending
-- on the deepest existing table and applied when the missing table is assigned.
-- @param mod  table to start from
-- @param name dotted path of the function, e.g. 'a.b.c'
-- @param func replacement function
-- @return true on success; false plus a message when the target cannot be patched
function patch(mod, name, func)
    assert(mod)
    assert(type(name) == 'string' and type(func) == 'function')
    local sep = '%.'
    local name_list = sep:split(name)
    if #name_list == 0 then
        return false, 'name should not be emtpy'
    end

    -- Walk down the path as far as it exists; stop on the table that owns
    -- (or will own) name_list[i].
    local i = 1
    while mod and i <= #name_list-1 do
        -- cannot find the name, the patch has to be registered on mod
        if not mod[name_list[i]] then
            break
        end
        mod = mod[name_list[i]]
        i = i + 1
    end
    assert(mod)
    -- Only tables can carry a metatable set from Lua; report other values
    -- instead of failing later with a raw index/setmetatable error.
    if type(mod) ~= 'table' then
        return false, string.format('Cannot patch %s: container is of type %s', name, type(mod))
    end
    local cur_mod_name = name_list[i]

    local existing = mod[cur_mod_name]
    if existing then
        -- cannot patch a non-function
        if type(existing) ~= 'function' then
            return false, string.format('Cannot patch %s of type %s', name, type(existing))
        end
        -- Save the original and clear the slot so that later assignments go
        -- through __newindex. When the name is already patched the original
        -- has been saved before (here or by __newindex).
        if not patched_modules[mod] or not patched_modules[mod][cur_mod_name] then
            original_funcs[mod] = original_funcs[mod] or {}
            original_funcs[mod][cur_mod_name] = existing
            mod[cur_mod_name] = nil
        end
    end

    -- Install our metamethods once per table (also avoids wrapping our own hooks).
    if not patched_modules[mod] then
        local mt = getmetatable(mod) or {}
        local old_index, old_newindex = mt.__index, mt.__newindex
        if old_index or old_newindex then
            original_mt[mod] = { __index = old_index, __newindex = old_newindex }
        end
        -- __index is installed unconditionally: a table first registered for a
        -- pending (intermediate) patch may be leaf-patched later, and that
        -- later call does not touch the metatable again.
        mt.__index = create_index(old_index)
        -- __newindex notifies us when a function or sub-table is assigned.
        mt.__newindex = create_newindex(old_newindex)
        setmetatable(mod, mt)
    end

    -- remember which original this patch function stands for
    if i == #name_list then
        local saved = original_funcs[mod]
        if saved and saved[cur_mod_name] then
            patched_unpatched_map[func] = saved[cur_mod_name]
        end
    end

    -- mark as patched, keyed by the part of the path that is relative to mod
    patched_modules[mod] = patched_modules[mod] or {}
    patched_modules[mod][table.concat(name_list, '.', i)] = func

    return true
end

-- Restore the metamethods of mod and forget it once no patch entry is left.
local function release_if_unpatched(mod)
    local entries = patched_modules[mod]
    if not entries or next(entries) ~= nil then
        return
    end
    local mt = getmetatable(mod)
    if mt then
        local saved = original_mt[mod]
        if saved then
            mt.__index = saved.__index
            mt.__newindex = saved.__newindex
        else
            mt.__index = nil
            mt.__newindex = nil
        end
        setmetatable(mod, mt)
    end
    original_mt[mod] = nil
    patched_modules[mod] = nil
end

-- Remove patches.
--   unpatch()          removes every patch on every module
--   unpatch(mod)       removes every patch registered on mod
--   unpatch(mod, name) removes the patch for the dotted path name
-- There is no need to pass the exact same mod/name pair that patch was called with.
-- @return true on success; false plus a message when nothing could be unpatched
function unpatch(mod, name)
    if not mod then
        -- unpatch all modules; iterate over a snapshot because the registry
        -- shrinks while we work
        local mods = {}
        for m in pairs(patched_modules) do
            mods[#mods + 1] = m
        end
        for idx = 1, #mods do
            if patched_modules[mods[idx]] then
                unpatch(mods[idx])
            end
        end
        return true
    elseif not name then
        -- unpatch this module
        local entries = patched_modules[mod]
        if not entries then
            return false, 'Cannot unpatch ' .. tostring(mod)
        end
        local names = {}
        for funcname in pairs(entries) do
            names[#names + 1] = funcname
        end
        for idx = 1, #names do
            unpatch(mod, names[idx])
        end
        return true
    end

    assert(mod and type(name) == 'string')

    local sep = '%.'
    local name_list = sep:split(name)
    local i = 1
    -- Find the table the patch is registered on; it may be an intermediate one.
    -- Invariant: name_list[i] belongs to mod
    while i < #name_list do
        local child = mod[name_list[i]]
        if type(child) ~= 'table' then
            break
        end
        -- Drop pending info left on the way down. A table we merely walk
        -- through may never have been registered at all.
        local entries = patched_modules[mod]
        if entries then
            entries[table.concat(name_list, '.', i)] = nil
            release_if_unpatched(mod)
        end
        mod = child
        i = i + 1
    end

    local new_name = table.concat(name_list, '.', i)
    local entries = patched_modules[mod]
    -- the module is not patched, or this function is not patched
    if not entries or not entries[new_name] then
        return false, 'Cannot find ' .. name .. ' to unpatch, it might be an un-patched function'
    end

    -- restore the original function (nil when there was none before patching)
    if i == #name_list then
        patched_unpatched_map[entries[new_name]] = nil

        local saved = original_funcs[mod]
        -- mod still carries our metatable, so bypass it
        rawset(mod, new_name, saved and saved[new_name])
        if saved then
            saved[new_name] = nil
            if next(saved) == nil then
                original_funcs[mod] = nil
            end
        end
    end
    entries[new_name] = nil

    -- if this was the last patch on mod, restore its metatable
    release_if_unpatched(mod)
    return true
end

-- Look up the original (unpatched) function.
-- Two calling forms:
--   get_unpatched_func(patched_func)  -- by the patch function itself
--   get_unpatched_func(mod, name)     -- by module and dotted function name
-- @return the saved original function, or nil when none is recorded
function get_unpatched_func(mod, name)
    assert(mod)

    -- single-argument form: reverse lookup by patch function
    if not name then
        assert(type(mod) == 'function', 'If name is not present, mod must be a patched function')
        return patched_unpatched_map[mod]
    end

    assert(type(name) == 'string', 'If mod and name are present, mod must be a module, name must be the original function name')

    local sep = '%.'
    local parts = sep:split(name)
    local depth = 1
    -- descend while the next path component exists, stopping before the last one
    while depth < #parts do
        local child = mod[parts[depth]]
        if not child then
            break
        end
        mod = child
        depth = depth + 1
    end

    local remainder = table.concat(parts, '.', depth)
    local saved = original_funcs[mod]
    return saved and saved[remainder]
end


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值