using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Diagnostics;
// 工具定义
// 注意: 需要 "unsafe编译" ([项目属性 - 生成] 勾选"允许不安全代码")
// 仅仅实现了 x64支持
namespace MethodReplace
{
public static class MethodReplaceUtil
{
public static bool ReplaceMethod(MethodInfo method, MethodInfo new_method)
{
if (IntPtr.Size != 8 || new_method.IsVirtual)
{
// 仅仅支持 x64
// 替换用的函数必须是普通函数
return false;
}
RuntimeHelpers.PrepareMethod(method.MethodHandle);
RuntimeHelpers.PrepareMethod(new_method.MethodHandle);
try
{
// 虚函数
if (method.IsVirtual)
{
return ReplaceVirtualMethod(method, new_method);
}
// 普通函数
#if DEBUG
const bool buildModelDebug = true;
#else
const bool buildModelDebug = false;
#endif
if (buildModelDebug && Debugger.IsAttached)
{
return ReplaceMethod_Debug_Attached(method, new_method);
}
else
{
return ReplaceMethod_Release(method, new_method);
}
}
catch(Exception e)
{
return false;
}
}
// [1]普通成员函数
static bool ReplaceMethod_Release(MethodInfo method, MethodInfo new_method)
{
unsafe
{
long* inj = (long*)new_method.MethodHandle.Value.ToPointer() + 1;
long* tar = (long*)method.MethodHandle.Value.ToPointer() + 1;
*tar = *inj;
}
return true;
}
// [2]普通成员函数(Debug且Attached处理)
static bool ReplaceMethod_Debug_Attached(MethodInfo method, MethodInfo new_method)
{
unsafe
{
long* inj = (long*)new_method.MethodHandle.Value.ToPointer() + 1;
long* tar = (long*)method.MethodHandle.Value.ToPointer() + 1;
byte* injInst = (byte*)*inj;
byte* tarInst = (byte*)*tar;
int* injSrc = (int*)(injInst + 1);
int* tarSrc = (int*)(tarInst + 1);
*tarSrc = *injSrc + ((int)injInst - (int)tarInst);
}
return true;
}
// [3]处理虚函数
// method 是虚函数
// new_method 是普通函数
static bool ReplaceVirtualMethod(MethodInfo method, MethodInfo new_method)
{
unsafe
{
UInt64* methodDesc = (UInt64*)(method.MethodHandle.Value.ToPointer());
int index = (int)(((*methodDesc) >> 32) & 0xFF);
ulong* classStart = (ulong*)method.DeclaringType.TypeHandle.Value.ToPointer();
classStart += 8;
classStart = (ulong*)*classStart;
ulong* tar = classStart + index;
ulong* inj = (ulong*)new_method.MethodHandle.Value.ToPointer() + 1;
*tar = *inj;
}
return true;
}
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
// 测试
namespace CSMethodLoad
{
// 准备替换成员函数的类
class ClassA
{
public int data = 100;
// 普通函数
public void show(int n)
{
Console.WriteLine("show sum=" + (data + n));
}
// 虚函数
public virtual void show_virtual(int n)
{
Console.WriteLine("show_virtual sum=" + (data + n));
}
}
// 提供替换函数的类
// 从 ClassA派生, 但是不定义任何数据成员和虚函数
// 以便数据成员和虚函数列表兼容
// 定义一些普通成员函数用来替换ClassA的成员函数
class ClassB : ClassA
{
// 用于替换 show
public void show_new(int n)
{
Console.WriteLine("show sum=" + (data + n) + " [new]");
}
// 用于替换 show_virtual
// 注意: 这个函数不是虚函数
public void show_virtual_new(int n)
{
Console.WriteLine("show_virtual sum=" + (data + n) + " [new]");
}
}
class Program
{
static void Main(string[] args)
{
Console.WriteLine("---- 原始函数调用 ----");
var a = new ClassA();
a.show(200);
a.show_virtual(300);
Console.WriteLine("");
Console.WriteLine("---- 用 show_new 替换掉 show ----");
var show = typeof(ClassA).GetMethod("show");
var show_new = typeof(ClassB).GetMethod("show_new");
MethodReplace.MethodReplaceUtil.ReplaceMethod(show, show_new);
Console.WriteLine("---- 用 show_virtual_new 替换掉 show_virtual ----");
var show_virtual = typeof(ClassA).GetMethod("show_virtual");
var show_virtual_new = typeof(ClassB).GetMethod("show_virtual_new");
MethodReplace.MethodReplaceUtil.ReplaceMethod(show_virtual, show_virtual_new);
Console.WriteLine("");
Console.WriteLine("---- 替换之后调用 ----");
a.show(200);
a.show_virtual(300);
Console.WriteLine("");
Console.WriteLine("结束");
Console.ReadKey(true);
}
}
}
/*
运行结果
---- 原始函数调用 ----
show sum=300
show_virtual sum=400
---- 用 show_new 替换掉 show ----
---- 用 show_virtual_new 替换掉 show_virtual ----
---- 替换之后调用 ----
show sum=300 [new]
show_virtual sum=400 [new]
结束
*/