无注册表的COM调用

对于 COM，一般使用 CoCreateInstance 来创建对象，这要求 COM 组件的 DLL 事先用 regsvr32 注册，因为 CoCreateInstance 需要读取注册表信息才能完成相应操作。
CoCreateInstance 的主要工作是：读取注册表找到组件 DLL，调用 CoLoadLibrary 加载该 DLL，调用其导出函数 DllGetClassObject 获取类厂（IClassFactory），最后调用类厂的 CreateInstance 创建对象。
如果我们自己来完成 CoCreateInstance 的这些工作，就可以实现无注册表的 COM 调用。
假设有一个简单的 COM 组件，它的接口很简单，只有一个减法函数 Minus。
普通 COM 调用的代码如下：
//普通COM
void TestCom1()
{
cout<<"TestCom1"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;
HRESULT hr = CoInitialize(NULL); //初始化COM
//使用SUCCEEDED宏并检查我们是否能得到一个接口指针
if(SUCCEEDED(hr))
{
hr = CoCreateInstance(CLSID_RCom,
NULL,
CLSCTX_INPROC_SERVER,
IID_IRCom,
(void **)&iCom);
//如果成功,则调用Minus方法,否则显示相应的出错信息
if(SUCCEEDED(hr))
{
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout << "CoCreateInstance Failed." << endl;
}
}
CoUninitialize();//释放COM
}
如果不使用注册表，可以自己加载 DLL 并通过类厂创建对象：
//无注册表COM
void TestCom2()
{
cout<<"TestCom2"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;
char *pDLLName = "ComFree.dll";
wchar_t szDLLPath[MAX_PATH];
MultiByteToWideChar(CP_ACP, NULL, pDLLName, strlen(pDLLName)+1, szDLLPath, MAX_PATH);
HMODULE hModule = CoLoadLibrary(szDLLPath, TRUE);
if(hModule)
{
pDllGetClassObject pfGetClassObj = (pDllGetClassObject)GetProcAddress(hModule, "DllGetClassObject");
if (pfGetClassObj)
{
IClassFactory *pFac;
HRESULT hr = pfGetClassObj(CLSID_RCom,IID_IClassFactory,(void **)&pFac);
if (SUCCEEDED(hr))
{
hr = pFac->CreateInstance(NULL, IID_IRCom, (void **)&iCom);
pFac->Release();
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout<<hr<<endl;
cout << "DllGetClassObject Failed." << endl;
}
}
}
else
{
cout << "CoLoadLibrary Failed." << endl;
}
}
那么，能否在不改变调用代码的前提下实现无注册表的 COM 调用呢？
万能的 Hook 可以做到这一点：
修改可执行文件导入表（IAT）中 Ole32.dll 的 CoCreateInstance 项，让它指向我们自己的实现。
完整代码：
#include<iostream>
using namespace std;
#include "ComFree_i.h"
#include "ComFree_i.c"
#include <Dbghelp.h>

#pragma comment(lib,"ole32.lib")
#pragma comment(lib,"Dbghelp.lib")
// Function-pointer type matching the DllGetClassObject export of an in-process
// COM server DLL; used to call the pointer obtained from GetProcAddress.
typedef HRESULT (__stdcall *pDllGetClassObject)(REFCLSID rclsid, REFIID riid, LPVOID * ppv);

// Highest user-mode application address. HookOle treats import-table entries
// that point above this value as possible debugger thunks and inspects their
// first byte. As a static it starts out NULL, so it needs to be filled in
// (e.g. from SYSTEM_INFO::lpMaximumApplicationAddress) for that test to mean anything.
static PVOID sm_pvMaxAppAddr;
// x86 opcode byte of "PUSH imm32". HookOle looks for it at the start of a
// debug-thunk wrapper and reads the pointer-sized value right after it as the
// real target address.
const BYTE cPushOpCode = 0x68;

//普通COM
void TestCom1()
{
cout<<"TestCom1"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;
HRESULT hr = CoInitialize(NULL); //初始化COM
//使用SUCCEEDED宏并检查我们是否能得到一个接口指针
if(SUCCEEDED(hr))
{
hr = CoCreateInstance(CLSID_RCom,
NULL,
CLSCTX_INPROC_SERVER,
IID_IRCom,
(void **)&iCom);
//如果成功,则调用Minus方法,否则显示相应的出错信息
if(SUCCEEDED(hr))
{
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout << "CoCreateInstance Failed." << endl;
}
}
CoUninitialize();//释放COM
}

//无注册表COM
void TestCom2()
{
cout<<"TestCom2"<<endl;
//声明HRESULT和IRCom接口指针
IRCom* iCom = NULL;

char *pDLLName = "ComFree.dll";
wchar_t szDLLPath[MAX_PATH];
MultiByteToWideChar(CP_ACP, NULL, pDLLName, strlen(pDLLName)+1, szDLLPath, MAX_PATH);
HMODULE hModule = CoLoadLibrary(szDLLPath, TRUE);
if(hModule)
{
pDllGetClassObject pfGetClassObj = (pDllGetClassObject)GetProcAddress(hModule, "DllGetClassObject");
if (pfGetClassObj)
{
IClassFactory *pFac;
HRESULT hr = pfGetClassObj(CLSID_RCom,IID_IClassFactory,(void **)&pFac);
if (SUCCEEDED(hr))
{
hr = pFac->CreateInstance(NULL, IID_IRCom, (void **)&iCom);
pFac->Release();
long ret;
iCom->Minus(8,9,&ret);
cout << "The answer for 8-9 is:" << ret << endl;
iCom->Release();
}
else
{
cout<<hr<<endl;
cout << "DllGetClassObject Failed." << endl;
}
}
}
else
{
cout << "CoLoadLibrary Failed." << endl;
}

}
// HOOK: make the regular CoCreateInstance call work without the registry.

// Maps an ASCII upper-case letter to its lower-case form; any other byte value
// (including non-ASCII bytes) is returned unchanged.
static int FoldAsciiLower(unsigned char ch)
{
    return (ch >= 'A' && ch <= 'Z') ? ch + ('a' - 'A') : ch;
}

// ASCII case-insensitive comparison of two NUL-terminated strings.
// Returns 0 when equal. Otherwise it returns the difference between the first
// pair of case-folded bytes (compared as unsigned char) that differ, or the
// difference at the point where dst ends.
int CompareStringNoCase(const char* dst, const char* src)
{
    for (;; ++dst, ++src)
    {
        const int lhs = FoldAsciiLower(static_cast<unsigned char>(*dst));
        const int rhs = FoldAsciiLower(static_cast<unsigned char>(*src));
        if (lhs == 0 || lhs != rhs)
        {
            return lhs - rhs;
        }
    }
}


// Replacement for CoCreateInstance that HookOle patches into the executable's
// import table. Instead of consulting the registry, it loads ComFree.dll, asks
// that DLL's DllGetClassObject export for a class factory for rclsid, and
// creates the object through the factory.
//
// Parameters mirror CoCreateInstance. dwClsContext is ignored: the object is
// always served in-process from ComFree.dll.
// Returns:
//   E_POINTER        if ppv is NULL
//   CO_E_DLLNOTFOUND if ComFree.dll cannot be loaded
//   CO_E_ERRORINDLL  if the DLL does not export DllGetClassObject
//   otherwise the HRESULT of DllGetClassObject / IClassFactory::CreateInstance.
// *ppv is NULL on every failure path.
HRESULT WINAPI HookCoCreateInstance(REFCLSID rclsid,LPUNKNOWN pUnkOuter,DWORD dwClsContext,REFIID riid,LPVOID * ppv)
{
    cout << "HookCoCreateInstance" << endl;
    (void)dwClsContext; // only in-process creation from ComFree.dll is supported

    if (ppv == NULL)
        return E_POINTER;
    *ppv = NULL; // like CoCreateInstance: no object unless creation succeeds

    // CoLoadLibrary takes a non-const LPOLESTR, hence the writable buffer.
    wchar_t szDLLPath[] = L"ComFree.dll";
    HMODULE hModule = CoLoadLibrary(szDLLPath, TRUE);
    if (!hModule)
    {
        cout << "CoLoadLibrary Failed." << endl;
        return CO_E_DLLNOTFOUND;
    }

    pDllGetClassObject pfGetClassObj =
        reinterpret_cast<pDllGetClassObject>(GetProcAddress(hModule, "DllGetClassObject"));
    if (!pfGetClassObj)
    {
        // Must be a failure code: returning S_OK here would make callers use a NULL object.
        cout << "GetProcAddress(DllGetClassObject) Failed." << endl;
        return CO_E_ERRORINDLL;
    }

    IClassFactory* pFac = NULL;
    HRESULT hr = pfGetClassObj(rclsid, IID_IClassFactory, reinterpret_cast<void**>(&pFac));
    if (FAILED(hr))
    {
        cout << hr << endl;
        cout << "DllGetClassObject Failed." << endl;
        return hr;
    }

    // Forward pUnkOuter so aggregation behaves as with the real CoCreateInstance.
    hr = pFac->CreateInstance(pUnkOuter, riid, ppv);
    pFac->Release();
    return hr;
}
// Patches the import address table (IAT) of the main executable so that its
// calls to Ole32!CoCreateInstance go to HookCoCreateInstance instead.
//
// Only the module returned by GetModuleHandle(NULL) is patched; other modules
// in the process keep their own import entries.
// Returns TRUE once the matching import entry has been overwritten, and FALSE
// in these cases:
//   - Ole32/CoCreateInstance cannot be resolved
//   - the executable does not import it
//   - the page protection cannot be changed
//   - a C++ exception occurs
BOOL HookOle()
{
    const char* const pszCalleeModName = "Ole32.dll";
    const char* const pszFuncName = "CoCreateInstance";

    // The address of the real CoCreateInstance is the value the loader wrote
    // into the import entry we are looking for.
    HMODULE hmodCallee = ::GetModuleHandleA(pszCalleeModName);
    if (hmodCallee == NULL)
        hmodCallee = ::LoadLibraryA(pszCalleeModName);
    if (hmodCallee == NULL)
        return FALSE;
    PROC pfnCurrent = ::GetProcAddress(hmodCallee, pszFuncName);
    if (pfnCurrent == NULL)
        return FALSE; // nothing valid to compare import entries against

    // The debug-thunk test below compares against sm_pvMaxAppAddr. Left at its
    // zero-initialized value, every import entry would pass that test and have
    // its code bytes inspected.
    if (sm_pvMaxAppAddr == NULL)
    {
        SYSTEM_INFO si;
        ::GetSystemInfo(&si);
        sm_pvMaxAppAddr = si.lpMaximumApplicationAddress;
    }

    HMODULE hmodCaller = ::GetModuleHandle(NULL);
    PROC pfnNew = reinterpret_cast<PROC>(HookCoCreateInstance);
    try
    {
        ULONG ulSize;
        // Get the address of the module's import section.
        PIMAGE_IMPORT_DESCRIPTOR pImportDesc =
            static_cast<PIMAGE_IMPORT_DESCRIPTOR>(ImageDirectoryEntryToData(
                hmodCaller,
                TRUE,
                IMAGE_DIRECTORY_ENTRY_IMPORT,
                &ulSize));
        if (pImportDesc == NULL)
            return FALSE; // the module has no import section

        // A module can carry more than one descriptor for the same DLL, so keep
        // scanning until the terminating descriptor (Name == 0).
        for (;;)
        {
            // Advance to the next descriptor that refers to the callee module.
            while (pImportDesc->Name)
            {
                LPCSTR lpszName = reinterpret_cast<LPCSTR>(
                    reinterpret_cast<PBYTE>(hmodCaller) + pImportDesc->Name);
                if (CompareStringNoCase(lpszName, pszCalleeModName) == 0)
                    break; // found
                pImportDesc++;
            }
            if (pImportDesc->Name == 0)
                return FALSE; // no (further) imports from the callee

            PIMAGE_THUNK_DATA pThunk = reinterpret_cast<PIMAGE_THUNK_DATA>(
                reinterpret_cast<PBYTE>(hmodCaller) + static_cast<UINT_PTR>(pImportDesc->FirstThunk));

            for (; pThunk->u1.Function; ++pThunk)
            {
                PROC* ppfn = reinterpret_cast<PROC*>(&pThunk->u1.Function);
                BOOL bFound = (*ppfn == pfnCurrent);

                if (!bFound && reinterpret_cast<PVOID>(*ppfn) > sm_pvMaxAppAddr)
                {
                    // A wrapper (debug thunk) that starts with PUSH imm32 keeps the
                    // real target address right after the opcode byte.
                    PBYTE pbInFunc = reinterpret_cast<PBYTE>(*ppfn);
                    if (pbInFunc[0] == cPushOpCode)
                    {
                        ppfn = reinterpret_cast<PROC*>(&pbInFunc[1]);
                        bFound = (*ppfn == pfnCurrent);
                    }
                }

                if (!bFound)
                    continue;

                MEMORY_BASIC_INFORMATION mbi;
                if (::VirtualQuery(ppfn, &mbi, sizeof(mbi)) == 0)
                    return FALSE;

                // Make the entry writable; the previous protection is saved in mbi.Protect.
                if (!::VirtualProtect(mbi.BaseAddress, mbi.RegionSize,
                                      PAGE_READWRITE, &mbi.Protect))
                {
                    return FALSE;
                }

                *ppfn = pfnNew; // hook the function

                // Restore the original protection.
                DWORD dwOldProtect;
                ::VirtualProtect(mbi.BaseAddress, mbi.RegionSize,
                                 mbi.Protect, &dwOldProtect);
                return TRUE;
            }
            pImportDesc++;
        }
    }
    catch (...)
    {
        // Best effort: report "not hooked". Note that access violations are
        // structured exceptions and only reach this handler when built with /EHa.
    }

    return FALSE;
}

int main(int argc, char* argv[])
{
if(HookOle())//如果把这行注释掉,TestCom1将失败
{
TestCom1();
}
TestCom2();
return 0;
}

客户程序

COM

组件程序(DLL)

CLSID clsid;

IClassFactory* pClf;

IUnknown* pUnknown;

CoInitialize(NULL);

CLSIDFromProgID(L"Dictionary.Object", &clsid);

 

 

 

COM在注册表中查找字典CLSID

 

CoGetClassObject(clsid, CLSCTX_INPROC_SERVER,

NULL, IID_IClassFactory, (void **)&pClf);

 

 

 

COM库在内存中查找clsid组件

if(DictComp.dll还没有被装入内存)

{

从注册表中获取组件程序全路径名 "…\DictComp.dll";

CoLoadLibrary();

}

DllGetClassObject(clsid, IID_IClassFactory, (void **)&pClf);

 

 

 

创建类厂对象CDictionaryFactory,

并返回IClassFactory接口

 

COM库返回IClassFactory接口给客户

 

pClf->CreateInstance(NULL, IID_IUnknown, (void **)&pUnknown);

 

 

 

 

类厂对象的CreateInstance函数被调用

(通过组件的vtable被客户直接调用)

new操作符构造字典组件对象

new CDictionary, 并返回IUnknown接口指针

客户使用字典组件,通过其接口进行各种操作

pClf->Release();

pUnknown->Release();

 

 

 

 

组件对象的Release函数被调用

if(m_Ref == 0)

{

delete this;

return 0;

}

CoFreeUnusedLibraries();

 

 

 

COM库调用字典组件的引出函数DllCanUnloadNow()

 

 

 

DllCanUnloadNow函数中:

if(不存在字典对象 && 锁计数为0)

   return S_OK;

else

   return S_FALSE;

 

if(组件DllCanUnloadNow()返回S_OK)

{

    CoFreeLibrary(…);

}

 

CoUninitialize();

 

 

 

COM库释放资源

 

客户程序退出

 

 

阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页