LLVM PASS--虚函数保护

虚函数实现分析

CPP实现:

#include <iostream>

class Test
{
public:
    Test() noexcept
    {
        printf("Test::Test()\n");
    }

    virtual ~Test()
    {
        printf("Virtual ~Test()\n");
    }

    virtual void prointer() = 0;
    virtual void pointf() = 0;
};

class TestA :public Test
{
public:
    TestA() noexcept
    {
        printf("TestA::TestA()\n");
    }

    virtual ~TestA()
    {
        printf("Virtual ~TestA()\n");
    }

    virtual void prointer()
    {
        printf("TestA::prointer\n");
    }

    virtual void pointf()
    {
        printf("TestA::pointf\n");
    }
};

int main()
{
    Test* pTest = new TestA;
    pTest->pointf();
    pTest->prointer();
    delete pTest;

    return 0;
}

执行输出结果:

Test::Test()
TestA::TestA()
TestA::pointf
TestA::prointer
Virtual ~TestA()
Virtual ~Test()

Test*类型的pTest变量是如何感知到TestA变量的pointf实现的呢?
IDA
通过函数指针数组访问的,new的是Test则指针数组指向Test的函数地址,TestA指向TestA的函数地址,即原因
结合IDA分析可知虚函数调用逻辑如下图
struct
示意图

虚函数攻击&保护

攻击原理概述:虚函数属于间接跳转,如果TestA对象的第一个指针(pVirtualTable)被修改则访问的函数即为攻击者所提供
保护原理概述:CFI即控制流完整性保护,以各种角度验证vtable的间接访问是否合法

llvm-pass动态插件部署

见上一篇文章:LLVM 13.1 new Pass插件形式 [for win]
建议自行使用文件硬链接创建clang目录下的插件"链接",方便后期测试执行

虚函数调用的LLVM-IR解析

输入命令clang vcallTest.cpp -S -emit-llvm -o vcallTest.ll 生成该CPP的llvm ir形式的人类可读文件,详细注释如下

# 定义一个返回类型是i32(int),参数为空的main函数
define dso_local i32 @main() #0 {
    # alloca指令在栈帧上分配的局部变量,访问或存储时必须使用load和store指令(后续不赘述)
    # 申请一个i32且4字节对齐的变量,分配给%1,%1类型为i32*
    %1 = alloca i32, align 4

    # 申请一个class.Test*类型的变量,分配给%2
    %2 = alloca %class.Test*, align 8

    # 将i32类型的0存储到%1变量中(nullptr)
    store i32 0, i32* %1, align 4

    # 调用new函数申请8字节的堆内存并将返回值分配给%3
    %3 = call noalias nonnull i8* @"??2@YAPEAX_K@Z"(i64 8) #10

    # 将i8*类型的%3变量转换为class.TestA*类型的变量给到%4
    %4 = bitcast i8* %3 to %class.TestA*

    # 调用TestA构造函数并将返回值分配给%5
    %5 = call %class.TestA* @"??0TestA@@QEAA@XZ"(%class.TestA* nonnull align 8 dereferenceable(8) %4) #11

    # 将class.TestA*类型的变量%4转换为class.Test*类型的变量%6
    %6 = bitcast %class.TestA* %4 to %class.Test*

    # 将%6变量的值存储到%2中
    store %class.Test* %6, %class.Test** %2, align 8

    # 把class.Test**类型的变量%2解引用存储到%7%7 = load %class.Test*, %class.Test** %2, align 8

    # 把class.Test*的变量%7转为void (%class.Test*)***类型变量%8%8 = bitcast %class.Test* %7 to void (%class.Test*)***

    # 解引用%8类型的值存储到%9
    %9 = load void (%class.Test*)**, void (%class.Test*)*** %8, align 8

    # 把%9的类型+2%10
    %10 = getelementptr inbounds void (%class.Test*)*, void (%class.Test*)** %9, i64 2

    # 解引用%10存储至%11
    %11 = load void (%class.Test*)*, void (%class.Test*)** %10, align 8

    # 调用函数%11,参数为%7(this)
    call void %11(%class.Test* nonnull align 8 dereferenceable(8) %7)

    # 解引用%2%12
    %12 = load %class.Test*, %class.Test** %2, align 8

    # 转换class.Test*类型的%12到void (%class.Test*)***类型的%13
    %13 = bitcast %class.Test* %12 to void (%class.Test*)***

    # 解引用%13%14
    %14 = load void (%class.Test*)**, void (%class.Test*)*** %13, align 8

    # 把%14的类型+1%10
    %15 = getelementptr inbounds void (%class.Test*)*, void (%class.Test*)** %14, i64 1

    # 把%10解引用给%16
    %16 = load void (%class.Test*)*, void (%class.Test*)** %15, align 8

    # 调用%16函数,参数为%12(this)
    call void %16(%class.Test* nonnull align 8 dereferenceable(8) %12)

    # 解引用%2%17
    %17 = load %class.Test*, %class.Test** %2, align 8

    # 比较%17null的关系结果存到%18
    %18 = icmp eq %class.Test* %17, null

    # 调到lable%25(null)或者lable%19(not null)
    br i1 %18, label %25, label %19

19:                                               ; preds = %0
    # 转换class.Test*类型的%17为i8* (%class.Test*, i32)***类型的%20
    %20 = bitcast %class.Test* %17 to i8* (%class.Test*, i32)***

    # 解引用%20%21
    %21 = load i8* (%class.Test*, i32)**, i8* (%class.Test*, i32)*** %20, align 8

    # 把%21+0%22
    %22 = getelementptr inbounds i8* (%class.Test*, i32)*, i8* (%class.Test*, i32)** %21, i64 0

    # 解引用%22%23
    %23 = load i8* (%class.Test*, i32)*, i8* (%class.Test*, i32)** %22, align 8

    # 调用%23函数,参数为%17(this,1)
    %24 = call i8* %23(%class.Test* nonnull align 8 dereferenceable(8) %17, i32 1) #11

    # 跳转到label%25
    br label %25

25:                                               ; preds = %19, %0
    # 返回i32类型的0
    ret i32 0
}

修改PASS观察输出

修改上一篇文章介绍的pass中的MyCustomPass::run函数为:

PreservedAnalyses run(Module& M, ModuleAnalysisManager& AM)
{
    for (auto& function : M) {
        if (!function.getName().compare("main")){
            for (auto& block : function){
                for(auto inst = block.begin(); inst != block.end(); ++inst){
                    errs() << ">>>>" << inst->getOpcodeName() << "\n";
                }
                errs() << "----------------------------" << "\n";
            }
        }
    }
    return PreservedAnalyses::all();
}

观察输出与上面的LLVM-LR文件的关联

clang vcallTest.cpp -o vcallTest.exe
>>>>alloca
>>>>alloca
>>>>store
>>>>call
>>>>bitcast
>>>>call
>>>>bitcast
>>>>store
>>>>load
>>>>bitcast
>>>>load
>>>>getelementptr
>>>>load
>>>>call
>>>>load
>>>>bitcast
>>>>load
>>>>getelementptr
>>>>load
>>>>call
>>>>load
>>>>icmp
>>>>br
----------------------------
>>>>bitcast
>>>>load
>>>>getelementptr
>>>>load
>>>>call
>>>>br
----------------------------
>>>>ret
----------------------------

可以看到一模一样,所以cpp—>llvmir—>pass—>link大概是这样的步骤

虚函数的识别

上述问题就是目前难度最大的一个问题了,从IR中可以看到明显的虚函数调用痕迹,但是这里会有个充要条件的问题。
可以看到这里的每个虚函数调用都会存在一组下列调用链,但是下述调用链不一定是虚函数调用

>>>>bitcast			要将类实例指针(this)类型转换为虚函数的指针的指针的指针类型
>>>>load			解应用this得到vtable指针
>>>>getelementptr	计算vtable[index]得到要调用的虚函数的指针的指针
>>>>load			解引用得到虚函数的指针
>>>>call			调用虚函数

更严谨的判断方式,以上调用链 + 第一个bitcast为:类指针—>函数的三重指针
如此只要不刻意伪造看上去很难与正常代码碰撞

虚函数保护思路

  1. 类似safeseh的形式,将每个识别出的虚函数的vtable指针记录下来,然后放到合法list中去,调用前比对是否合法
  2. CFIXX(NDSS’18):构造函数中把vtable加密了,虚函数调用前解密(改源码实现为用pass实现)

CFIXX思路的简化版实现:半成品

pass代码

#include <llvm/IR/PassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/PassPlugin.h>
#include <llvm/IR/IRBuilder.h>

using namespace llvm;

// test for my custom pass
class MyCustomPass : public PassInfoMixin<MyCustomPass> {
public:
    unsigned int counts = 0;

    bool InsertPrintf(Instruction* inst);
    void InjectVirtualCall(BasicBlock& block);
    PreservedAnalyses run(Module& M, ModuleAnalysisManager& AM);
};

// 在clang里根据配置创建自定义pass,called by PassManagerBuilder::populateModulePassManager
extern "C"  __declspec(dllexport) void __stdcall clangAddCustomPass(ModulePassManager & MPM)
{
    MPM.addPass(MyCustomPass());
}

PreservedAnalyses MyCustomPass::run(Module& M, ModuleAnalysisManager& AM)
{
    for (auto& function : M)
    {
        if (!function.getName().compare("main"))
        {
            for (auto& block : function)
            {
                InjectVirtualCall(block);
            }
        }
    }

    return counts ? PreservedAnalyses::all() : PreservedAnalyses::all();
}

void MyCustomPass::InjectVirtualCall(BasicBlock& block)
{
    int deepth = 0;
    for (auto& inst : block)
    {
        switch (inst.getOpcode())
        {
        case Instruction::BitCast:
            deepth == 0 ? deepth++ : deepth = 0;
            break;
        case Instruction::Load:
            (deepth == 1 || deepth == 3) ? deepth++ : deepth = 0;
            break;
        case Instruction::GetElementPtr:
            deepth == 2 ? deepth++ : deepth = 0;
            break;
        case Instruction::Call:
            deepth == 4 ? deepth++ : deepth = 0;
            break;
        default:
            deepth = 0;
            break;
        }

        if (deepth == 1)
        {
            auto bcInst = dyn_cast<BitCastInst>(&inst);
            auto srcTy = bcInst->getSrcTy();
            auto dstTy = bcInst->getDestTy();
            if (srcTy->getTypeID() != Type::PointerTyID ||
                dstTy->getTypeID() != Type::PointerTyID ||
                srcTy->getPointerElementType()->getTypeID() != Type::StructTyID ||
                dstTy->getPointerElementType()->getTypeID() != Type::PointerTyID ||
                dstTy->getPointerElementType()->getPointerElementType()->getTypeID() != Type::PointerTyID ||
                dstTy->getPointerElementType()->getPointerElementType()->getPointerElementType()->getTypeID() != Type::FunctionTyID)
            {
                deepth = 0;
                continue;
            }
        }
        else if (deepth == 5)
        {
            InsertPrintf(&inst);
            deepth = 0;
            continue;
        }
    }
}

bool MyCustomPass::InsertPrintf(Instruction* inst)
{
    auto mod = inst->getModule();
    if (!mod) return false;

    PointerType* printfArgTy = PointerType::getUnqual(Type::getInt8Ty(mod->getContext()));
    FunctionType* printfTy = FunctionType::get(IntegerType::getInt32Ty(mod->getContext()), printfArgTy, true);
    FunctionCallee funcPtr = mod->getOrInsertFunction("printf", printfTy);
    if (!funcPtr) return false;

    IRBuilder<> builder(inst);
    Constant* constString = ConstantDataArray::getString(mod->getContext(), "vcall counts:%d \n", true);
    Constant* constStringVar = mod->getOrInsertGlobal("helloWord", constString->getType());
    dyn_cast<GlobalVariable>(constStringVar)->setInitializer(constString);
    builder.CreateCall(funcPtr, { constStringVar,builder.getInt32(counts++) });
}

测试cpp代码:

#include <iostream>
#include <map>

std::map<int, int64_t> vtableMap;

class Test
{
public:
    Test() noexcept
    {
        printf("Test::Test()\n");
    }

    virtual ~Test()
    {
        printf("Virtual ~Test()\n");
    }

    virtual void prointer() = 0;
    virtual void pointf() = 0;
};

class TestA :public Test
{
public:
    TestA() noexcept
    {
        int64_t vtablePtr = *(int64_t*)this;
        *(int64_t*)this = rand();
        vtableMap.insert(std::pair<int, int64_t>(*(int*)this, vtablePtr));
        printf("TestA::TestA()\n");
    }

    virtual ~TestA()
    {
        printf("Virtual ~TestA()\n");
    }

    virtual void prointer()
    {
        printf("TestA::prointer\n");
    }

    virtual void pointf()
    {
        printf("TestA::pointf\n");
    }
};

extern "C" uint64_t DecryptPointer(uint64_t vtablePtr)
{
    return vtableMap[vtablePtr];
}
uint64_t DecryptPointer2(uint64_t vtablePtr)
{
    return vtableMap[vtablePtr];
}
uint64_t DecryptPointer3(uint64_t vtablePtr)
{
    return vtableMap[vtablePtr];
}

int main()
{
    srand((unsigned)time(NULL));
    Test* pTest = new TestA;
    pTest->pointf();
    pTest->prointer();
    delete pTest;

    return 0;
}

测试

编译执行结果

clang vcallTest.cpp -o vcallTest.exe
./vcallTest.exe
Test::Test()
TestA::TestA()
TestA::pointf
TestA::prointer
Virtual ~TestA()
Virtual ~Test()

总结

  1. 类的构造函数中把真正的vtable拿走,然后赋予sign值,后续用pass在每个虚函数调用前使用sign恢复出真的vtableptr,此处的实现是写入一个map中,当然因为是简易实现所以没有考虑sign碰撞等问题,但不难解决。按照CFIXX的思路map的内存必须是安全的,依此实现vtableptr的安全性
  2. 类的构造函数此处还是手动插入的vtable指针修改,可以改成自动替换,但是会有个问题就是如何筛选类构造函数以及如何只保护自定义类的构造函数?有个最简单的方式是自定义函数名前N个字节固定,这样规律的类构造函数才进行保护
  3. 本章后续有缘再更新

参考:
CFIXX代码仓库
LLVM pass 实现 C++虚表保护

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值