内存加载支持多依赖的DLL加载器实现解析
基于纯内存的DLL依赖链解决方案
在传统的Windows应用开发中,LoadLibrary 是用户态(R3)加载模块的核心函数。闲暇时我编写了一份替代方案;网上已有的一些代码为我节省了大部分时间,但它们并不能完全替代 LoadLibrary:大多缺少对模块依赖链的自动解析与修复,也就是说,当需要加载多个存在依赖关系的DLL时就会出错。
以典型的依赖场景 B.dll -> A.dll(B 导入 A 的函数)为例,LoadLibrary 在背后完成了三项关键操作:
- PEB链表维护:将模块信息注入进程环境块(Process Environment Block)
- IAT隐式修复:通过系统加载器自动重建导入地址表(Import Address Table)
- 依赖解析与初始化顺序:递归加载多级依赖,并保证被依赖的模块先于依赖它的模块完成初始化(DllMain)
而纯内存加载方案则面临:
- 模块状态对系统不可见,无法通过 GetModuleHandle 查询
- 导入函数地址需手动重建,缺乏标准解析接口
- 多级依赖导致初始化时序错位风险
本文将深入剖析Windows PE加载器的运行时行为,提出一套完整的内存模块元数据管理框架
核心功能解析
1. 多级依赖加载
- 全局模块管理器跟踪已加载模块
- 支持内存模块与系统DLL混合加载
- 按导入表名称查找已注册的内存模块(被依赖的模块需先于依赖方加载)
2. 安全内存管理
- 双重内存分配策略(首选基址/自动分配)
- 动态内存保护设置(按节区属性)
- 基于RAII的资源释放(析构时调用DLL_PROCESS_DETACH并释放映像内存)
3. 完整PE解析
- 支持重定位表处理
- 导入表解析与IAT填充
- 导出表精确查找
4. 生命周期管理
- 自动调用DllMain入口点
- 统一模块注册机制
- 安全内存释放保障
使用示例
基本加载流程
// Usage sketch: map a DLL whose image is already in memory, then look up an export.
// Raw PE file bytes of the DLL (placeholder: fill in the real bytes, e.g. from an
// embedded resource). The loader stores a pointer to this buffer rather than a
// copy, so it must remain valid while `loader` is in use.
const unsigned char demoDll[] = { /* PE file bytes */ };
// Create the loader. The third argument is the name the module is registered
// under after a successful Load(); later loads resolve their imports by the DLL
// name found in their import table, so it should match that name
// (e.g. "DemoModule.dll").
MemoryDLLLoader loader(demoDll, sizeof(demoDll), "DemoModule");
// Map, relocate, link and initialize the image; returns false on any failure.
if (loader.Load()) {
// Look up an export by name; nullptr if it cannot be found.
auto pFunc = loader.GetFunction("ExportFunc");
// GetFunction returns void*: cast it to the export's real function-pointer type
// before calling it (the signature depends on the DLL - TODO fill in).
}
依赖管理示例
// 加载依赖模块
const unsigned char depDll[] = { /* 依赖DLL数据 */ };
MemoryDLLLoader depLoader(depDll, sizeof(depDll), "DepModule");
depLoader.Load();
// 主模块自动解析依赖
const unsigned char mainDll[] = { /* 主DLL数据 */ };
MemoryDLLLoader mainLoader(mainDll, sizeof(mainDll), "MainModule");
mainLoader.Load(); // 自动发现已加载的DepModule
处理多模块依赖关系
模块查找算法
// Map an imported DLL name to the image base of a module registered by an
// earlier memory load; yields nullptr when no such module is registered.
static HMODULE FindModule(const char* name) {
    const std::string key{name};
    return MemoryModuleManager::GetModule(key);
}
处理重定位
bool ProcessRelocations() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
DWORD_PTR delta = reinterpret_cast<DWORD_PTR>(m_hModule) -
ntHeaders->OptionalHeader.ImageBase;
if (delta == 0) return true; // 不需要重定位
// 处理重定位表
IMAGE_DATA_DIRECTORY dir = ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_BASERELOC];
if (dir.VirtualAddress == 0) return false;
IMAGE_BASE_RELOCATION* reloc = reinterpret_cast<IMAGE_BASE_RELOCATION*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress);
while (reinterpret_cast<DWORD_PTR>(reloc) <
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress + dir.Size) {
DWORD_PTR offset = reloc->VirtualAddress;
DWORD count = (reloc->SizeOfBlock - sizeof(IMAGE_BASE_RELOCATION)) / sizeof(WORD);
WORD* items = reinterpret_cast<WORD*>(reloc + 1);
for (DWORD i = 0; i < count; ++i) {
if ((items[i] >> 12) == IMAGE_REL_BASED_HIGHLOW) {
DWORD_PTR* patch = reinterpret_cast<DWORD_PTR*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + offset + (items[i] & 0xFFF));
*patch += delta;
}
}
reloc = reinterpret_cast<IMAGE_BASE_RELOCATION*>(
reinterpret_cast<DWORD_PTR>(reloc) + reloc->SizeOfBlock);
}
return true;
}
处理导入表
bool ResolveImports() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
IMAGE_DATA_DIRECTORY dir = ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
if (dir.VirtualAddress == 0) return true;
IMAGE_IMPORT_DESCRIPTOR* importDesc = reinterpret_cast<IMAGE_IMPORT_DESCRIPTOR*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress);
while (importDesc->Name) {
const char* dllName = reinterpret_cast<const char*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->Name);
// 强制要求依赖项必须已内存加载
HMODULE hDll = FindModule(dllName);
if (!hDll)
{
//若依赖系统DLL则用LoadLibrary
hDll = LoadLibraryA(dllName);
}
if (!hDll) {
std::cerr << "[错误] 依赖项 " << dllName << " 未通过内存加载" << std::endl;
return false;
}
// 处理双Thunk结构(OriginalFirstThunk和FirstThunk)
PIMAGE_THUNK_DATA origThunk = nullptr;
if (importDesc->OriginalFirstThunk != 0) {
origThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->OriginalFirstThunk);
}
else {
origThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->FirstThunk);
}
PIMAGE_THUNK_DATA thunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->FirstThunk);
while (origThunk->u1.AddressOfData != 0) {
// 获取函数地址
FARPROC procAddr = nullptr;
if (origThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) {
// 按序号导入
auto ordinal = IMAGE_ORDINAL(origThunk->u1.Ordinal);
procAddr = MemoryGetProcAddress(hDll, (LPCSTR)ordinal);
}
else {
// 按名称导入
auto importByName = reinterpret_cast<IMAGE_IMPORT_BY_NAME*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + origThunk->u1.AddressOfData);
if (g_memoryModules.count(dllName)) { // 内存模块
procAddr = MemoryGetProcAddress(hDll, importByName->Name);
}
else { // 系统模块
procAddr = GetProcAddress(hDll, importByName->Name);
}
}
if (!procAddr) {
std::cerr << "[错误] 无法解析函数地址: " << dllName << std::endl;
return false;
}
// 写入IAT
thunk->u1.Function = reinterpret_cast<ULONG_PTR>(procAddr);
// 移动到下一个Thunk
++origThunk;
++thunk;
}
++importDesc;
}
return true;
}
在内存中解析导出函数地址(替代GetProcAddress)
// Look up an export in a module mapped at `hModule`, without the system loader.
//
// `funcName` is either a NUL-terminated export name or an ordinal encoded with
// MAKEINTRESOURCEA (high bits zero), mirroring GetProcAddress.
// Returns nullptr if the module has no export directory, the name/ordinal is
// not exported, or the export is forwarded to another DLL (forwarder RVAs point
// at a "Dll.Function" string, not at code, and are not resolved here).
FARPROC MemoryGetProcAddress(HMODULE hModule, const char* funcName) {
    BYTE* const base = reinterpret_cast<BYTE*>(hModule);
    if (!base) {
        return nullptr;
    }
    const auto* dosHeader = reinterpret_cast<const IMAGE_DOS_HEADER*>(base);
    const auto* ntHeaders = reinterpret_cast<const IMAGE_NT_HEADERS*>(base + dosHeader->e_lfanew);

    const IMAGE_DATA_DIRECTORY& exportDir =
        ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT];
    if (exportDir.VirtualAddress == 0 || exportDir.Size == 0) {
        return nullptr;  // module exports nothing
    }
    const auto* exports =
        reinterpret_cast<const IMAGE_EXPORT_DIRECTORY*>(base + exportDir.VirtualAddress);
    const DWORD* functions = reinterpret_cast<const DWORD*>(base + exports->AddressOfFunctions);

    DWORD index = 0;
    if ((reinterpret_cast<ULONG_PTR>(funcName) >> 16) == 0) {
        // Ordinal lookup: the function table is indexed by (ordinal - Base).
        const DWORD ordinal = static_cast<DWORD>(reinterpret_cast<ULONG_PTR>(funcName));
        if (ordinal < exports->Base) {
            return nullptr;
        }
        index = ordinal - exports->Base;
    } else {
        const DWORD* names = reinterpret_cast<const DWORD*>(base + exports->AddressOfNames);
        const WORD* ordinals = reinterpret_cast<const WORD*>(base + exports->AddressOfNameOrdinals);
        DWORD i = 0;
        while (i < exports->NumberOfNames &&
               strcmp(reinterpret_cast<const char*>(base + names[i]), funcName) != 0) {
            ++i;
        }
        if (i == exports->NumberOfNames) {
            return nullptr;  // name not exported
        }
        index = ordinals[i];
    }

    if (index >= exports->NumberOfFunctions) {
        return nullptr;
    }
    const DWORD rva = functions[index];
    if (rva == 0) {
        return nullptr;  // unused slot in the ordinal range
    }
    if (rva >= exportDir.VirtualAddress && rva < exportDir.VirtualAddress + exportDir.Size) {
        return nullptr;  // forwarded export: RVA points at a forwarder string
    }
    return reinterpret_cast<FARPROC>(base + rva);
}
设置内存保护
bool SetMemoryProtection() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(ntHeaders);
for (WORD i = 0; i < ntHeaders->FileHeader.NumberOfSections; ++i) {
DWORD protect = 0;
DWORD oldProtect;
DWORD characteristics = section[i].Characteristics;
void* address = reinterpret_cast<void*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + section[i].VirtualAddress);
SIZE_T size = section[i].Misc.VirtualSize;
if (characteristics & IMAGE_SCN_MEM_EXECUTE) {
protect = (characteristics & IMAGE_SCN_MEM_WRITE) ? PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
} else {
protect = (characteristics & IMAGE_SCN_MEM_WRITE) ? PAGE_READWRITE : PAGE_READONLY;
}
if (!VirtualProtect(address, size, protect, &oldProtect)) {
return false;
}
}
return true;
}
完整代码
#include <windows.h>
#include <winnt.h>

#include <algorithm>
#include <cctype>
#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
// Name -> image base of every module successfully loaded by MemoryDLLLoader::Load(),
// keyed by the exact name passed to the loader (kept alongside MemoryModuleManager).
std::unordered_map<std::string, HMODULE> g_memoryModules{};
class MemoryModuleManager {
public:
static void RegisterModule(const std::string& name, HMODULE hModule) {
GetInstance().m_modules[name] = hModule;
}
static HMODULE GetModule(const std::string& name) {
auto& instance = GetInstance();
if (auto it = instance.m_modules.find(name); it != instance.m_modules.end()) {
return it->second;
}
return nullptr;
}
private:
std::unordered_map<std::string, HMODULE> m_modules;
static MemoryModuleManager& GetInstance() {
static MemoryModuleManager instance;
return instance;
}
};
class MemoryDLLLoader {
public:
MemoryDLLLoader(const unsigned char* dllData, size_t dllSize, const char* moduleName)
: m_hModule(nullptr), m_pData(dllData), m_Size(dllSize), m_name(moduleName) {}
bool Load() {
// 验证PE文件签名
if (!ValidatePEHeader()) {
std::cerr << "Invalid PE file format" << std::endl;
return false;
}
// 分配内存空间
if (!AllocateMemory()) {
std::cerr << "Memory allocation failed" << std::endl;
return false;
}
// 处理内存重定位
if (!ProcessRelocations()) {
std::cerr << "Relocation processing failed" << std::endl;
return false;
}
// 处理导入表
if (!ResolveImports()) {
std::cerr << "Import resolution failed" << std::endl;
return false;
}
// 设置内存保护
if (!SetMemoryProtection()) {
std::cerr << "Memory protection setup failed" << std::endl;
return false;
}
if (!CallDllMain()) {
std::cerr << "Call DllMain failed" << std::endl;
return false;
}
if (m_hModule) {
g_memoryModules[m_name] = m_hModule;
}
if (m_hModule) {
MemoryModuleManager::RegisterModule(m_name, m_hModule);
}
return m_hModule != nullptr;;
}
void* GetFunction(const char* funcName) {
if (!m_hModule) return nullptr;
return GetProcAddress(m_hModule, funcName);
}
~MemoryDLLLoader() {
if (m_hModule) {
// 调用DLL_PROCESS_DETACH
CallDllMain(false);
// 释放内存
VirtualFree(m_hModule, 0, MEM_RELEASE);
}
}
private:
HMODULE m_hModule;
const unsigned char* m_pData;
size_t m_Size;
const char* m_name;
const IMAGE_NT_HEADERS* GetNTHeaders() const {
const IMAGE_DOS_HEADER* pDosHeader =
reinterpret_cast<const IMAGE_DOS_HEADER*>(m_pData);
return reinterpret_cast<const IMAGE_NT_HEADERS*>(
m_pData + pDosHeader->e_lfanew);
}
FARPROC MemoryGetProcAddress(HMODULE hModule, const char* funcName) {
// 获取PE头
BYTE* base = reinterpret_cast<BYTE*>(hModule);
IMAGE_DOS_HEADER* dosHeader = reinterpret_cast<IMAGE_DOS_HEADER*>(base);
IMAGE_NT_HEADERS* ntHeaders = reinterpret_cast<IMAGE_NT_HEADERS*>(base + dosHeader->e_lfanew);
// 定位导出表
IMAGE_EXPORT_DIRECTORY* exports = reinterpret_cast<IMAGE_EXPORT_DIRECTORY*>(
base + ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT].VirtualAddress);
DWORD* names = reinterpret_cast<DWORD*>(base + exports->AddressOfNames);
WORD* ordinals = reinterpret_cast<WORD*>(base + exports->AddressOfNameOrdinals);
DWORD* functions = reinterpret_cast<DWORD*>(base + exports->AddressOfFunctions);
// 按名称查找
for (DWORD i = 0; i < exports->NumberOfNames; ++i) {
const char* name = reinterpret_cast<const char*>(base + names[i]);
if (strcmp(name, funcName) == 0) {
return reinterpret_cast<FARPROC>(base + functions[ordinals[i]]);
}
}
return nullptr;
}
bool ValidatePEHeader() {
// 检查DOS头
const IMAGE_DOS_HEADER* dosHeader = reinterpret_cast<const IMAGE_DOS_HEADER*>(m_pData);
if (dosHeader->e_magic != IMAGE_DOS_SIGNATURE) return false;
// 检查PE头
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
return ntHeaders->Signature == IMAGE_NT_SIGNATURE;
}
bool AllocateMemory() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
DWORD size = ntHeaders->OptionalHeader.SizeOfImage;
// 分配可读可写可执行内存
m_hModule = reinterpret_cast<HMODULE>(VirtualAlloc(
reinterpret_cast<LPVOID>(ntHeaders->OptionalHeader.ImageBase),
size,
MEM_COMMIT | MEM_RESERVE,
PAGE_READWRITE));
if (!m_hModule) {
// 如果首选地址不可用,尝试其他地址
m_hModule = reinterpret_cast<HMODULE>(VirtualAlloc(
nullptr,
size,
MEM_COMMIT | MEM_RESERVE,
PAGE_READWRITE));
}
if (!m_hModule) return false;
// 复制PE头
memcpy(m_hModule, m_pData, ntHeaders->OptionalHeader.SizeOfHeaders);
// 复制节区数据
PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(ntHeaders);
for (WORD i = 0; i < ntHeaders->FileHeader.NumberOfSections; ++i) {
void* dest = reinterpret_cast<void*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + section[i].VirtualAddress);
const void* src = m_pData + section[i].PointerToRawData;
memcpy(dest, src, section[i].SizeOfRawData);
}
return true;
}
bool ProcessRelocations() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
DWORD_PTR delta = reinterpret_cast<DWORD_PTR>(m_hModule) -
ntHeaders->OptionalHeader.ImageBase;
if (delta == 0) return true; // 不需要重定位
// 处理重定位表
IMAGE_DATA_DIRECTORY dir = ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_BASERELOC];
if (dir.VirtualAddress == 0) return false;
IMAGE_BASE_RELOCATION* reloc = reinterpret_cast<IMAGE_BASE_RELOCATION*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress);
while (reinterpret_cast<DWORD_PTR>(reloc) <
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress + dir.Size) {
DWORD_PTR offset = reloc->VirtualAddress;
DWORD count = (reloc->SizeOfBlock - sizeof(IMAGE_BASE_RELOCATION)) / sizeof(WORD);
WORD* items = reinterpret_cast<WORD*>(reloc + 1);
for (DWORD i = 0; i < count; ++i) {
if ((items[i] >> 12) == IMAGE_REL_BASED_HIGHLOW) {
DWORD_PTR* patch = reinterpret_cast<DWORD_PTR*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + offset + (items[i] & 0xFFF));
*patch += delta;
}
}
reloc = reinterpret_cast<IMAGE_BASE_RELOCATION*>(
reinterpret_cast<DWORD_PTR>(reloc) + reloc->SizeOfBlock);
}
return true;
}
static HMODULE FindModule(const char* name) {
return MemoryModuleManager::GetModule(name);
}
bool ResolveImports() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
IMAGE_DATA_DIRECTORY dir = ntHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
if (dir.VirtualAddress == 0) return true;
IMAGE_IMPORT_DESCRIPTOR* importDesc = reinterpret_cast<IMAGE_IMPORT_DESCRIPTOR*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + dir.VirtualAddress);
while (importDesc->Name) {
const char* dllName = reinterpret_cast<const char*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->Name);
// 强制要求依赖项必须已内存加载
HMODULE hDll = FindModule(dllName);
if (!hDll)
{
//若依赖系统DLL则用LoadLibrary
hDll = LoadLibraryA(dllName);
}
if (!hDll) {
std::cerr << "[错误] 依赖项 " << dllName << " 未通过内存加载" << std::endl;
return false;
}
// 处理双Thunk结构(OriginalFirstThunk和FirstThunk)
PIMAGE_THUNK_DATA origThunk = nullptr;
if (importDesc->OriginalFirstThunk != 0) {
origThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->OriginalFirstThunk);
}
else {
origThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->FirstThunk);
}
PIMAGE_THUNK_DATA thunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
reinterpret_cast<DWORD_PTR>(m_hModule) + importDesc->FirstThunk);
while (origThunk->u1.AddressOfData != 0) {
// 获取函数地址
FARPROC procAddr = nullptr;
if (origThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) {
// 按序号导入
auto ordinal = IMAGE_ORDINAL(origThunk->u1.Ordinal);
procAddr = MemoryGetProcAddress(hDll, (LPCSTR)ordinal);
}
else {
// 按名称导入
auto importByName = reinterpret_cast<IMAGE_IMPORT_BY_NAME*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + origThunk->u1.AddressOfData);
if (g_memoryModules.count(dllName)) { // 内存模块
procAddr = MemoryGetProcAddress(hDll, importByName->Name);
}
else { // 系统模块
procAddr = GetProcAddress(hDll, importByName->Name);
}
}
if (!procAddr) {
std::cerr << "[错误] 无法解析函数地址: " << dllName << std::endl;
return false;
}
// 写入IAT
thunk->u1.Function = reinterpret_cast<ULONG_PTR>(procAddr);
// 移动到下一个Thunk
++origThunk;
++thunk;
}
++importDesc;
}
return true;
}
bool SetMemoryProtection() {
const IMAGE_NT_HEADERS* ntHeaders = GetNTHeaders();
PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(ntHeaders);
for (WORD i = 0; i < ntHeaders->FileHeader.NumberOfSections; ++i) {
DWORD protect = 0;
DWORD oldProtect;
DWORD characteristics = section[i].Characteristics;
void* address = reinterpret_cast<void*>(
reinterpret_cast<DWORD_PTR>(m_hModule) + section[i].VirtualAddress);
SIZE_T size = section[i].Misc.VirtualSize;
if (characteristics & IMAGE_SCN_MEM_EXECUTE) {
protect = (characteristics & IMAGE_SCN_MEM_WRITE) ? PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
} else {
protect = (characteristics & IMAGE_SCN_MEM_WRITE) ? PAGE_READWRITE : PAGE_READONLY;
}
if (!VirtualProtect(address, size, protect, &oldProtect)) {
return false;
}
}
return true;
}
bool CallDllMain(bool attach = true) {
using DllMainFn = BOOL(WINAPI*)(HINSTANCE, DWORD, LPVOID);
DllMainFn entry = reinterpret_cast<DllMainFn>(
reinterpret_cast<DWORD_PTR>(m_hModule) +
GetNTHeaders()->OptionalHeader.AddressOfEntryPoint);
// 确保入口点有效
if (!entry) return false;
return entry(reinterpret_cast<HINSTANCE>(m_hModule),
attach ? DLL_PROCESS_ATTACH : DLL_PROCESS_DETACH,
nullptr);
}
};