我写了一个简单的 monkey patcher，可以用它对已有的 module 或 table 进行 patch，目前只支持对 function 进行 patch。代码如下：
require 'list'
require 'string_ext'
require 'table_ext'

-- Capture the globals we need before module(...) swaps the environment:
-- after that call, unqualified names resolve inside the module table.
local assert = assert
local error = error
local getmetatable = getmetatable
local next = next
local pairs = pairs
local rawget = rawget
local rawset = rawset
local setmetatable = setmetatable
local string = string
local table = table
local tostring = tostring
local type = type

-- NOTE: module() is the Lua 5.1 style (deprecated since 5.2); it is kept so
-- the public interface (patch / unpatch / get_unpatched_func) is unchanged.
module(...)

-- Original functions, keyed by the table that owns them (leaf tables only):
-- after patching a.b.f, original_funcs[a.b].f holds the replaced function.
local original_funcs = {}

-- original_is_own[t][k] is true when original_funcs[t][k] was a raw field of
-- t (as opposed to a value reached through t's pre-existing __index). Only
-- such values are written back into t on unpatch.
local original_is_own = {}

-- Active patches per table. A key is either
--   * a leaf name ("f") mapping to the replacement function, or
--   * a dotted path ("b.f") recorded on an intermediate table whose child
--     ("b") did not exist at patch time; the patch is applied when a table
--     is later assigned to that child.
-- All intermediate entries are removed upon unpatching.
local patched_modules = {}

-- Metatable state captured when our hooks are installed on a table:
-- { had_mt = <table had a metatable>, index = old __index,
--   newindex = old __newindex }
local original_mt = {}

-- Replacement function -> the original function it currently shadows.
local patched_unpatched_map = {}

-- Split a dotted name into components, keeping empty components
-- ("a..b" -> {"a", "", "b"}) so that callers can reject malformed names.
local function split_name(name)
  local parts = {}
  for part in string.gmatch(name .. '.', '([^.]*)%.') do
    parts[#parts + 1] = part
  end
  return parts
end

-- True when k can be one component of a dotted name: a string with no dot.
local function is_plain_key(k)
  return type(k) == 'string' and not string.find(k, '.', 1, true)
end

-- The replacement function registered for leaf key k of table t, if any.
-- Dotted (intermediate) entries are never returned here.
local function leaf_patch(t, k)
  local entries = patched_modules[t]
  if entries and is_plain_key(k) then
    return entries[k]
  end
  return nil
end

-- Record fn as the original value of t[k]; own says whether it was (or is to
-- be restored as) a raw field of t.
local function remember_original(t, k, fn, own)
  original_funcs[t] = original_funcs[t] or {}
  original_funcs[t][k] = fn
  original_is_own[t] = original_is_own[t] or {}
  original_is_own[t][k] = own or nil
end

-- Look up t[k] the way the pre-existing __index would have.
local function chain_index(old_index, t, k)
  local kind = type(old_index)
  if kind == 'nil' then
    return nil
  elseif kind == 'table' then
    return old_index[k]
  elseif kind == 'function' then
    return old_index(t, k)
  end
  error('unknown type of __index')
end

-- Store t[k] = v the way the pre-existing __newindex would have (a plain
-- raw store when there was none), so the value is stored exactly once.
local function chain_newindex(old_newindex, t, k, v)
  local kind = type(old_newindex)
  if kind == 'nil' then
    rawset(t, k, v)
  elseif kind == 'table' then
    old_newindex[k] = v
  elseif kind == 'function' then
    old_newindex(t, k, v)
  else
    error('unknown type of __newindex')
  end
end

-- Table v has just been assigned to t[k]: apply every pending patch recorded
-- on t under "k.<rest>" to v as patch(v, "<rest>", func).
local function apply_pending(t, k, v)
  local entries = patched_modules[t]
  if not entries then
    return
  end
  local prefix = k .. '.'
  local pending = {}
  for path, func in pairs(entries) do
    -- Plain prefix comparison: k is data, not a Lua pattern.
    if string.sub(path, 1, #prefix) == prefix then
      pending[#pending + 1] = { string.sub(path, #prefix + 1), func }
    end
  end
  -- Patch only after the traversal: patch() may add entries, and v may be t.
  for j = 1, #pending do
    patch(v, pending[j][1], pending[j][2])
  end
end

-- __index hook: serve the replacement for a patched leaf, otherwise defer to
-- the pre-existing __index.
local function make_index(old_index)
  return function(t, k)
    local func = leaf_patch(t, k)
    if func ~= nil then
      return func
    end
    return chain_index(old_index, t, k)
  end
end

-- __newindex hook.
-- * Assigning a function to a patched leaf keeps the patch in effect and
--   records the new function as the original (restored on unpatch). It is
--   deliberately not forwarded to the old __newindex, which could store it
--   raw and shadow the patch.
-- * Anything else is forwarded exactly once to the old __newindex (or
--   stored raw); assigning a table also applies patches pending on it.
local function make_newindex(old_newindex)
  return function(t, k, v)
    local func = leaf_patch(t, k)
    if func and type(v) == 'function' then
      remember_original(t, k, v, true)
      patched_unpatched_map[func] = v
      return
    end
    chain_newindex(old_newindex, t, k, v)
    if not func and type(v) == 'table' and is_plain_key(k) then
      apply_pending(t, k, v)
    end
  end
end

-- Once table t has no patches left, drop its bookkeeping and restore the
-- metatable fields we replaced. If t had no metatable before patching and
-- ours is now empty, remove it.
-- NOTE(review): a metatable shared by several tables is modified in place,
-- so patching/unpatching them in non-LIFO order can leave stale hooks.
local function release_if_unused(t)
  local entries = patched_modules[t]
  if not entries or next(entries) ~= nil then
    return
  end
  patched_modules[t] = nil
  original_funcs[t] = nil
  original_is_own[t] = nil

  local saved = original_mt[t]
  original_mt[t] = nil
  local mt = getmetatable(t)
  if not saved or not mt then
    return
  end
  mt.__index = saved.index
  mt.__newindex = saved.newindex
  if not saved.had_mt and next(mt) == nil then
    setmetatable(t, nil)
  end
end

--- Replace the function at dotted path `name` under table `mod` with `func`.
-- Missing intermediate tables are allowed: the patch is applied once a table
-- is assigned to the missing key (only while that key is absent from the
-- raw table, since that is when __newindex fires).
-- @param mod   table to patch (a module or any table)
-- @param name  dotted path such as "a.b.f"
-- @param func  replacement function
-- @return true on success; false and an error message otherwise
function patch(mod, name, func)
  assert(type(mod) == 'table', 'mod must be a table')
  assert(type(name) == 'string' and type(func) == 'function')
  if name == '' then
    return false, 'name should not be empty'
  end
  local parts = split_name(name)
  for j = 1, #parts do
    if parts[j] == '' then
      return false, string.format('Invalid name %s: empty component', name)
    end
  end

  -- Walk down existing tables. Stop at the leaf's owner, or at the first
  -- missing component (the rest of the path is then recorded as pending on
  -- the deepest existing table).
  local depth = #parts
  local i = 1
  while i < depth do
    local child = mod[parts[i]]
    if not child then
      break
    end
    if type(child) ~= 'table' then
      return false, string.format('Cannot patch %s: %s is a %s',
        name, table.concat(parts, '.', 1, i), type(child))
    end
    mod = child
    i = i + 1
  end

  local key = parts[i]
  -- Truthy only when i == depth (the loop stops early only on a falsy child).
  local current = mod[key]
  if current then
    if type(current) ~= 'function' then
      return false, string.format('Cannot patch %s of type %s', name, type(current))
    end
    -- If already patched, the original was saved earlier (or replaced by
    -- __newindex) and `current` is the previous replacement: keep the former.
    if not leaf_patch(mod, key) then
      local own = rawget(mod, key) ~= nil
      remember_original(mod, key, current, own)
      if own then
        -- Remove the raw field so later reads and writes reach our hooks;
        -- rawset keeps any pre-existing __newindex out of this internal step.
        rawset(mod, key, nil)
      end
    end
  end

  -- Install hooks once per table, remembering what we replace.
  if not patched_modules[mod] then
    local mt = getmetatable(mod)
    local saved = {
      had_mt = mt ~= nil,
      index = mt and mt.__index,
      newindex = mt and mt.__newindex,
    }
    original_mt[mod] = saved
    mt = mt or {}
    mt.__index = make_index(saved.index)
    mt.__newindex = make_newindex(saved.newindex)
    setmetatable(mod, mt)
    patched_modules[mod] = {}
  end

  if i == depth then
    local saved = original_funcs[mod]
    if saved and saved[key] then
      patched_unpatched_map[func] = saved[key]
    end
  end

  patched_modules[mod][table.concat(parts, '.', i)] = func
  return true
end

--- Undo patches.
-- unpatch()           removes every patch and returns true.
-- unpatch(mod)        removes every patch recorded on mod; returns true, or
--                     false and a message if mod has none.
-- unpatch(mod, name)  removes the patch at dotted path `name`. The path does
--                     not need to match the one given to patch(): it is
--                     resolved from mod through the current tables.
-- @return true on success; false and an error message otherwise
function unpatch(mod, name)
  if not mod then
    -- Snapshot first: unpatching removes entries from patched_modules.
    local mods = {}
    for m in pairs(patched_modules) do
      mods[#mods + 1] = m
    end
    for j = 1, #mods do
      if patched_modules[mods[j]] then
        unpatch(mods[j])
      end
    end
    return true
  end

  if not name then
    local entries = patched_modules[mod]
    if not entries then
      return false, 'Cannot unpatch ' .. tostring(mod)
    end
    local paths = {}
    for path in pairs(entries) do
      paths[#paths + 1] = path
    end
    for j = 1, #paths do
      unpatch(mod, paths[j])
    end
    return true
  end

  assert(type(name) == 'string')
  local parts = split_name(name)
  local depth = #parts

  -- Descend towards the leaf's owner, dropping pending (intermediate) entries
  -- for this path on the way and releasing tables left without patches.
  local i = 1
  while i < depth do
    local child = mod[parts[i]]
    if type(child) ~= 'table' then
      break
    end
    local entries = patched_modules[mod]
    if entries then
      entries[table.concat(parts, '.', i)] = nil
      release_if_unused(mod)
    end
    mod = child
    i = i + 1
  end

  local key = parts[i]
  local path = table.concat(parts, '.', i)
  local entries = patched_modules[mod]
  if not entries or not entries[path] then
    return false, 'Cannot find ' .. name .. ' to unpatch, it might be an un-patched function'
  end

  if i == depth then
    patched_unpatched_map[entries[path]] = nil
    local saved = original_funcs[mod]
    local original = saved and saved[key]
    if original ~= nil then
      local own = original_is_own[mod]
      if own and own[key] then
        -- Our __newindex may still be installed; bypass it.
        rawset(mod, key, original)
        own[key] = nil
      end
      -- A non-own original is reachable again through the old __index.
      saved[key] = nil
    end
  end

  entries[path] = nil
  release_if_unused(mod)
  return true
end

--- Look up the original (unpatched) function.
-- get_unpatched_func(mod, name)  original function at dotted path `name`.
-- get_unpatched_func(func)       original shadowed by replacement `func`.
-- @return the original function, or nil if none is recorded
function get_unpatched_func(mod, name)
  assert(mod)
  if not name then
    assert(type(mod) == 'function', 'If name is not present, mod must be a patched function')
    return patched_unpatched_map[mod]
  end

  assert(type(name) == 'string', 'If mod and name are present, mod must be a module, name must be the original function name')
  local parts = split_name(name)
  local i = 1
  while i < #parts do
    local child = mod[parts[i]]
    if type(child) ~= 'table' then
      break
    end
    mod = child
    i = i + 1
  end
  local saved = original_funcs[mod]
  return saved and saved[table.concat(parts, '.', i)]
end