[Hands-On Project] A High-Concurrency Memory Pool (Modeled After tcmalloc)

Author: 爱写代码的刚子
Date: 2024-02-12
Preface:

  • This project implements a high-concurrency memory pool. Its prototype is Google's open-source project tcmalloc (Thread-Caching Malloc), which provides efficient multi-threaded memory management and is meant to replace the system allocation functions (malloc, free).
  • The project extracts and simplifies tcmalloc's core framework to build our own high-concurrency memory pool. The goal is to learn the essence of tcmalloc, much like the way we previously studied STL containers. Compared with the STL container work, however, tcmalloc's code size and complexity are much higher, and as the difficulty rises, so does what you gain from working through it.
  • tcmalloc was open-sourced by Google and can be regarded as the work of top C++ engineers of its time. It is very well known and used by many companies; the Go language adopted it as the basis of its own memory allocator. Many programmers are therefore familiar with this project, which cuts both ways: if you understand it thoroughly, interviewers will value it; but an interviewer who also knows it well may probe deeply into the details, and a shaky grasp of the project will quickly be exposed.

----------- This post provides source code only -----------

BenchMark.cpp

#include"ConcurrentAlloc.h"

// ntimes: number of allocations/frees per round
// nworks: number of threads
// rounds: number of rounds
void BenchmarkMalloc(size_t ntimes, size_t nworks, size_t rounds)
{
	std::vector<std::thread> vthread(nworks);
	std::atomic<size_t> malloc_costtime{0};
	std::atomic<size_t> free_costtime{0};

	for (size_t k = 0; k < nworks; ++k)
	{
		vthread[k] = std::thread([&, k]() {
			std::vector<void*> v;
			v.reserve(ntimes);

			for (size_t j = 0; j < rounds; ++j)
			{
				size_t begin1 = clock();
				for (size_t i = 0; i < ntimes; i++)
				{
					//v.push_back(malloc(16));
					v.push_back(malloc((16 + i) % 8192 + 1));
				}
				size_t end1 = clock();

				size_t begin2 = clock();
				for (size_t i = 0; i < ntimes; i++)
				{
					free(v[i]);
				}
				size_t end2 = clock();
				v.clear();

				malloc_costtime += (end1 - begin1);
				free_costtime += (end2 - begin2);
			}
			});
	}

	for (auto& t : vthread)
	{
		t.join();
	}

	printf("%u个线程并发执行%u轮次,每轮次malloc %u次: 花费:%u ms\n",
		nworks, rounds, ntimes, static_cast<unsigned int>(malloc_costtime));

	printf("%u个线程并发执行%u轮次,每轮次free %u次: 花费:%u ms\n",
		nworks, rounds, ntimes, static_cast<unsigned int>(free_costtime));

	printf("%u个线程并发malloc&free %u次,总计花费:%u ms\n",
		nworks, nworks * rounds * ntimes, static_cast<unsigned int>(malloc_costtime + free_costtime));
}


// ntimes: allocations/frees per round; nworks: number of threads; rounds: number of rounds
void BenchmarkConcurrentMalloc(size_t ntimes, size_t nworks, size_t rounds)
{
	std::vector<std::thread> vthread(nworks);
	std::atomic<size_t> malloc_costtime{0};
	std::atomic<size_t> free_costtime{0};

	for (size_t k = 0; k < nworks; ++k)
	{
		vthread[k] = std::thread([&]() {
			std::vector<void*> v;
			v.reserve(ntimes);

			for (size_t j = 0; j < rounds; ++j)
			{
				size_t begin1 = clock();
				for (size_t i = 0; i < ntimes; i++)
				{
					//v.push_back(ConcurrentAlloc(16));
					v.push_back(ConcurrentAlloc((16 + i) % 8192 + 1));
				}
				size_t end1 = clock();

				size_t begin2 = clock();
				for (size_t i = 0; i < ntimes; i++)
				{
					ConcurrentFree(v[i]);
				}
				size_t end2 = clock();
				v.clear();

				malloc_costtime += (end1 - begin1);
				free_costtime += (end2 - begin2);
			}
			});
	}

	for (auto& t : vthread)
	{
		t.join();
	}

	printf("%u个线程并发执行%u轮次,每轮次concurrent alloc %u次: 花费:%u ms\n",
		nworks, rounds, ntimes, static_cast<unsigned int>(malloc_costtime));

	printf("%u个线程并发执行%u轮次,每轮次concurrent dealloc %u次: 花费:%u ms\n",
		nworks, rounds, ntimes, static_cast<unsigned int>(free_costtime));

	printf("%u个线程并发concurrent alloc&dealloc %u次,总计花费:%u ms\n",
		nworks, nworks * rounds * ntimes, static_cast<unsigned int>(malloc_costtime + free_costtime));
}

int main()
{
	size_t n = 1000;
	cout << "==========================================================" << endl;
	BenchmarkConcurrentMalloc(n, 4, 10);
	cout << endl << endl;

	BenchmarkMalloc(n, 4, 10);
	cout << "==========================================================" << endl;

	return 0;
}
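
A note on the timing above: clock() counts CLOCKS_PER_SEC ticks per second (1000 on Windows, 1,000,000 on Linux/macOS), so the printed differences are only milliseconds on Windows. Below is a minimal wall-clock sketch using std::chrono that is unit-safe on every platform; the function name TimedMallocLoopMs is illustrative and not part of the project.

#include <chrono>
#include <vector>
#include <cstdlib>

// Sketch: time one allocation loop in real milliseconds, mirroring the malloc pattern above.
size_t TimedMallocLoopMs(size_t ntimes)
{
	std::vector<void*> v;
	v.reserve(ntimes);

	auto t0 = std::chrono::high_resolution_clock::now();
	for (size_t i = 0; i < ntimes; ++i)
		v.push_back(malloc((16 + i) % 8192 + 1));
	auto t1 = std::chrono::high_resolution_clock::now();

	for (void* p : v)
		free(p);

	return (size_t)std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count();
}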

CentralCache.cpp

#include "CentralCache.h"
#include "PageCache.h"
CentralCache CentralCache::_sInst;
//Get a span that still has free objects
Span* CentralCache::GetOneSpan(SpanList& list, size_t size)
{
    //Check whether the current spanlist still has a span with unallocated objects
    Span* it = list.Begin();
    while (it != list.End())
    {
        if (it->_freeList != nullptr)
        {
            return it;
        }
        else {
            it = it->_next;
        }
    }

    //Release the central cache bucket lock first, so other threads returning objects are not blocked
    list._mtx.unlock();
    //Reaching here means there is no span with free objects left; ask the page cache for one
    PageCache::GetInstance()->_pageMtx.lock();
    Span* span = PageCache::GetInstance()->NewSpan(SizeClass::NumMovePage(size));
    span->_isUsed = true;
    span->_objSize = size;
    PageCache::GetInstance()->_pageMtx.unlock();

    //Slice up the acquired span. No lock is needed: other threads cannot access this span yet.

    //Compute the start address and the size (in bytes) / end address of the span's big block
    char* start = (char*)(span->_pageId << PAGE_SHIFT);//page id -> start address (page 0 starts at address 0): pageId * 8KB; char* keeps pointer arithmetic simple
    size_t bytes = span->_n << PAGE_SHIFT;
    char* end = start + bytes;
    //Link the big block into a free list of size-byte objects (tail insertion)
    //1. Cut one object off first to act as the head, which makes tail insertion easy
    //
    span->_freeList = start;
    start += size;
    void* tail = span->_freeList;
    //int i=1;
    while (start < end)
    {
        //++i;
        NextObj(tail) = start;
        tail = NextObj(tail);//tail = start
        start += size;
    }
    NextObj(tail) = nullptr;//don't forget to null-terminate the list
    //Debug check (use a conditional breakpoint)
    //If a dead loop is suspected, break into the program; it stops wherever it is currently running
       /*int j = 0;
       void* cur = span->_freeList;
       while (cur)
       {
           cur = NextObj(cur);
           ++j;
       }
       if (j != (bytes/size))
       {
           int x = 0;
       }*/
       

    //After slicing, take the lock again only when hanging the span onto the bucket
    list._mtx.lock();
    list.PushFront(span);
    return span;
}
size_t CentralCache::FetchRangeObj(void*& start, void*& end, size_t batchNum, size_t size)
{
    size_t index = SizeClass::Index(size);
    _spanLists[index]._mtx.lock();

    Span* span = GetOneSpan(_spanLists[index], size);//find a non-empty span in _spanLists
    assert(span);
    assert(span->_freeList);

    //Take batchNum objects from the span
    //If fewer than batchNum are available, take as many as there are
    start = span->_freeList;
    end = start;
    size_t i = 0;
    size_t actualNum = 1;

    while (i < batchNum - 1 && NextObj(end) != nullptr)
    {
        end = NextObj(end);
        ++i;
        ++actualNum;
    }

    span->_freeList = NextObj(end);
    NextObj(end) = nullptr;
    span->_useCount += actualNum;
    //Debug check (conditional breakpoint)
      /* int j = 0;
       void* cur = start;
       while (cur)
       {
           cur = NextObj(cur);
           ++j;
       }
       if (actualNum!= j)
       {
           int x = 0;
       }
       */
       
    _spanLists[index]._mtx.unlock();

    return actualNum;
}
void CentralCache::ReleaseListToSpans(void* start, size_t size)
{
    size_t index = SizeClass::Index(size);
    _spanLists[index]._mtx.lock();
    while (start)
    {
        void* next = NextObj(start);
        Span* span = PageCache::GetInstance()->MapObjectToSpan(start);

        NextObj(start) = span->_freeList;
        span->_freeList = start;
        span->_useCount--;
        //All the small blocks sliced from this span have come back,
        //so the span can be returned to the page cache, which may then try to merge it with adjacent pages.
        if (span->_useCount == 0)
        {
            _spanLists[index].Erase(span);
            span->_freeList = nullptr;
            span->_next = nullptr;
            span->_prev = nullptr;
            //When returning the span to the page cache, the page cache lock is sufficient;
            //release the bucket lock for now
            _spanLists[index]._mtx.unlock();
            PageCache::GetInstance()->_pageMtx.lock();
            PageCache::GetInstance()->ReleaseSpanToPageCache(span);
            PageCache::GetInstance()->_pageMtx.unlock();
            _spanLists[index]._mtx.lock();
        }
        start = next;
    }
    _spanLists[index]._mtx.unlock();
}

CentralCache.h

#pragma once

#include "Common.h"
// Singleton (eager initialization)
class CentralCache
{
public:
    static CentralCache* GetInstance()
    {
        return &_sInst;
    }
    // Fetch a batch of objects from the central cache for a thread cache
    size_t FetchRangeObj(void*& start, void*& end, size_t n, size_t byte_size);
    //Get a non-empty span
    Span* GetOneSpan(SpanList& list, size_t byte_size);

    //Return a batch of objects to the spans they belong to
    void ReleaseListToSpans(void* start, size_t byte_size);
private:
    SpanList _spanLists[NFREELISTS];

private:
    CentralCache() {} // private constructor
    CentralCache& operator=(CentralCache const&) = delete;
    CentralCache(const CentralCache&) = delete; // C++11: delete the copy operations
    static CentralCache _sInst;
};

Common.h

#pragma once
#include <iostream>
#include <vector>
#include <atomic>
#include <cassert>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <thread>
#include <mutex>
#include <algorithm>
#include <unordered_map>
#include <map>

#ifdef _WIN32
#include <windows.h>
#else
// Linux / macOS
#include <unistd.h>
#include <sys/mman.h>
#include <stdexcept>
#endif
using std::cout;
using std::endl;


// Constants: const is preferred over macros here
static const size_t MAX_BYTES = 256 * 1024;
static const size_t NFREELISTS = 208;
static const size_t NPAGES = 129;
static const size_t PAGE_SHIFT = 13; // 2^13 bytes = 8KB per page
// Note: on 64-bit Windows both _WIN32 and _WIN64 are defined, so "#elif _WIN64" placed after _WIN32 would never be reached

#ifdef _WIN64 // therefore check _WIN64 first
typedef unsigned long long PAGE_ID;
#elif _WIN32
typedef size_t PAGE_ID;
#elif defined(__x86_64__) || defined(__amd64__) // linux64
typedef unsigned long long PAGE_ID;
#elif defined(__i386__)                         // linux32
typedef size_t PAGE_ID;
#elif defined(__arm64__) && defined(__APPLE__) // Mac ARM64
typedef unsigned long long PAGE_ID;
#else
typedef size_t PAGE_ID; // fallback for other platforms
#endif
//
inline static void* SystemAlloc(size_t kpage)
{
#ifdef _WIN32
    void* ptr = VirtualAlloc(0, kpage << PAGE_SHIFT, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
    if (ptr == nullptr)
        throw std::bad_alloc();
#else
    // Linux / macOS: request pages from the OS with mmap (brk/sbrk would also work)
    void* ptr = mmap(0, kpage << PAGE_SHIFT, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (ptr == MAP_FAILED)
        throw std::bad_alloc();
#endif
    return ptr;
}

inline static void SystemFree(void* ptr, size_t size)
{
#ifdef _WIN32
    VirtualFree(ptr, 0, MEM_RELEASE);
#else
    // Linux / macOS
    if (munmap(ptr, size) == -1) {
        // handle munmap failure
        perror("munmap");
        exit(EXIT_FAILURE);
    }
#endif
}

static void*& NextObj(void* obj)
{
    return *(void**)obj;
}

// Free list that manages small objects sliced from a span
class FreeList
{
public:
    void Push(void* obj)
    {
        assert(obj);
        // push front
        NextObj(obj) = _freeList;
        _freeList = obj;
        ++_size;
    }
    void PushRange(void*& start, void*& end, size_t n)
    {
        NextObj(end) = _freeList;
        _freeList = start;

        //Debug check (conditional breakpoint)
        /*int i = 0;
        void* cur = start;
        while (cur)
        {
            cur = NextObj(cur);
            ++i;
        }
        if (n != i)
        {
            int x = 0;
        }
        */
        _size += n;
    }
    void* Pop()
    {
        assert(_freeList);
        // pop front
        void* obj = _freeList;
        _freeList = NextObj(obj);
        --_size;
        return obj;
    }
    void PopRange(void*& start, void*& end, size_t n)
    {
        assert(n <= _size);
        start = _freeList;
        end = start;
        for (size_t i = 0; i < n - 1; ++i)
        {
            end = NextObj(end);
        }
        _freeList = NextObj(end);
        NextObj(end) = nullptr;
        _size -= n;
    }
    bool Empty()
    {
        return _freeList == nullptr;
    }
    size_t& MaxSize()
    {
        return _maxSize;
    }
    size_t Size()
    {
        return _size;
    }

private:
    void* _freeList = nullptr;
    size_t _maxSize = 1;
    size_t _size = 0;//built-in members should be given default values
};

// Alignment rules and bucket-index mapping for object sizes
class SizeClass
{

public:
    // Keeps internal fragmentation within roughly 10%
    //  [1,128]                  align to 8 bytes       freelist[0,16)
    //  [128+1,1024]             align to 16 bytes      freelist[16,72)
    //  [1024+1,8*1024]          align to 128 bytes     freelist[72,128)
    //  [8*1024+1,64*1024]       align to 1024 bytes    freelist[128,184)
    //  [64*1024+1,256*1024]     align to 8*1024 bytes  freelist[184,208)
    /*size_t _RoundUp(size_t size,size_t alignNum)
    {
        size_t alignSize;
        if(size%8!=0)
        {
            alignSize=(size/alignNum+1)*alignNum;
        }
        else{
            alignSize = size;
        }
        return alignSize;
    }*/

    // Bit tricks for speed
    static inline size_t _RoundUp(size_t bytes, size_t alignNum)
    {
        return ((bytes + alignNum - 1) & ~(alignNum - 1)); // round bytes up to the next multiple of alignNum (e.g. 7 -> 8); alignNum must be a power of two
    }

    static inline size_t RoundUp(size_t size)
    {
        if (size <= 128)
        {
            return _RoundUp(size, 8);
        }
        else if (size <= 1024)
        {
            return _RoundUp(size, 16);
        }
        else if (size <= 8 * 1024)
        {
            return _RoundUp(size, 128);
        }
        else if (size <= 64 * 1024)
        {
            return _RoundUp(size, 1024);
        }
        else if (size <= 256 * 1024)
        {
            return _RoundUp(size, 8 * 1024);
        }
        else
        {
            return _RoundUp(size, 1 << PAGE_SHIFT);
        }
    }

    /*static inline size_t _Index(size_t bytes,size_t alignNum)
    {
        if(bytes %alignNum==0)
        {
            return bytes / alignNum -1;
        }
        else{
            return bytes / alignNum;
        }
    }*/

    static inline size_t _Index(size_t bytes, size_t align_shift) // align_shift: log2 of the alignment
    {
        return ((bytes + (1 << align_shift) - 1) >> align_shift) - 1; // (1 << align_shift) is the alignment value
    }

    static inline size_t Index(size_t bytes)
    {
        assert(bytes <= MAX_BYTES);
        // number of buckets in each size range
        static int group_array[4] = { 16, 56, 56, 56 };
        if (bytes <= 128)
        {
            return _Index(bytes, 3);
        }
        else if (bytes <= 1024)
        {
            return _Index(bytes - 128, 4) + group_array[0];
        }
        else if (bytes <= 8 * 1024)
        {
            return _Index(bytes - 1024, 7) + group_array[1] + group_array[0];
        }
        else if (bytes <= 64 * 1024)
        {
            return _Index(bytes - 8 * 1024, 10) + group_array[2] + group_array[1] + group_array[0];
        }
        else if (bytes <= 256 * 1024)
        {
            return _Index(bytes - 64 * 1024, 13) + group_array[3] +
                group_array[2] + group_array[1] + group_array[0];
        }
        else
        {
            assert(false);
        }
        return -1;
    }

    // How many objects to fetch from the central cache at a time
    static size_t NumMoveSize(size_t size)
    {
        assert(size > 0);
        if (size == 0)
            return 0;
        // [2, 512]: upper bound of one batch move (slow start)
        // smaller objects get a higher batch upper bound, larger objects a lower one
        int num = MAX_BYTES / size;
        if (num < 2)
            num = 2;
        if (num > 512)
            num = 512;
        return num;
    }

    // How many pages to request from the system at a time
    // single object of 8 bytes
    // ...
    // single object of 256KB
    static size_t NumMovePage(size_t size)
    {
        size_t num = NumMoveSize(size);
        size_t npage = num * size; // total bytes
        npage >>= PAGE_SHIFT;
        if (npage == 0)
            npage = 1;
        return npage;
    }
};
// Span: manages a big block of memory spanning multiple contiguous pages
struct Span
{
    PAGE_ID _pageId = 0; // page id of the first page of the big block
    size_t _n = 0;       // number of pages

    Span* _next = nullptr; // doubly linked list structure
    Span* _prev = nullptr;

    size_t _useCount = 0;      // number of sliced small blocks handed out to thread caches
    size_t _objSize = 0;    // size of the sliced small objects
    void* _freeList = nullptr; // free list of the sliced small blocks

    bool _isUsed = false;                 // whether the span is in use
};

// Circular doubly linked list with a sentinel head node
class SpanList
{
public:
    SpanList()
    {
        _head = new Span;
        _head->_next = _head;
        _head->_prev = _head;
    }

    Span* Begin()
    {
        return _head->_next;
    }

    Span* End()
    {
        return _head;
    }
    bool Empty()
    {
        return _head->_next == _head;
    }
    void PushFront(Span* span)
    {
        Insert(Begin(), span);
    }
    Span* PopFront()
    {
        Span* front = _head->_next;
        Erase(front);//Erase does not return the node
        return front;
    }
    void Insert(Span* pos, Span* newSpan)
    {
        assert(pos);
        assert(newSpan);

        Span* prev = pos->_prev;

        prev->_next = newSpan;
        newSpan->_prev = prev;

        newSpan->_next = pos;
        pos->_prev = newSpan;
    }

    void Erase(Span* pos)
    {
        assert(pos);
        assert(pos != _head);
        //conditional breakpoint
        //inspect the call stack
        
        /*if (pos == _head)
        {
            int x = 0;
        }*/

        Span* prev = pos->_prev;
        Span* next = pos->_next;

        prev->_next = next;
        next->_prev = prev;

        // Do not delete the span; it is handed back to the upper layer
    }

private:
    Span* _head;

public:
    std::mutex _mtx; // bucket lock
};
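
To make the alignment mapping above concrete, here is a small hand-checked sketch (the function name CheckSizeClassMapping is illustrative): a 7-byte request rounds up to 8 bytes in bucket 0; a 129-byte request rounds up to 144 bytes in bucket 16, a worst case of 15/144, roughly 10% internal waste; a 1025-byte request rounds up to 1152 bytes in bucket 72.

// Sketch: values computed by hand from RoundUp/Index above
void CheckSizeClassMapping()
{
    assert(SizeClass::RoundUp(7) == 8 && SizeClass::Index(7) == 0);           // 8-byte-aligned range
    assert(SizeClass::RoundUp(129) == 144 && SizeClass::Index(129) == 16);    // 16-byte-aligned range
    assert(SizeClass::RoundUp(1025) == 1152 && SizeClass::Index(1025) == 72); // 128-byte-aligned range
}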

ConcurrentAlloc.h

#pragma once

#include "Common.h"
#include "ThreadCache.h"
#include "PageCache.h"
#include "ObjectPool.h"

static void* ConcurrentAlloc(size_t size)
{

    if (size > MAX_BYTES)
    {
        size_t alignSize = SizeClass::RoundUp(size);
        size_t kpage = alignSize >> PAGE_SHIFT;

        PageCache::GetInstance()->_pageMtx.lock();

        Span* span = PageCache::GetInstance()->NewSpan(kpage);
        span->_objSize = size;
        PageCache::GetInstance()->_pageMtx.unlock();
        void* ptr = (void*)(span->_pageId << PAGE_SHIFT);
        return ptr;
    }
    else
    {
        // Each thread gets its own dedicated ThreadCache object through TLS, without locking
        if (pTLSThreadCache == nullptr)
        {
            //static: a single object pool shared by all threads

            static ObjectPool<ThreadCache> tcPool;
            pTLSThreadCache = tcPool.New();

            //pTLSThreadCache = new ThreadCache;
        }
        // cout<<std::this_thread::get_id()<<":"<<pTLSThreadCache<<endl;

        return pTLSThreadCache->Allocate(size);
    }
}

static void ConcurrentFree(void* ptr)
{
    Span* span = PageCache::GetInstance()->MapObjectToSpan(ptr);
    size_t size = span->_objSize;
    if (size > MAX_BYTES)
    {

        PageCache::GetInstance()->_pageMtx.lock();
        PageCache::GetInstance()->ReleaseSpanToPageCache(span);
        PageCache::GetInstance()->_pageMtx.unlock();
    }
    else
    {
        assert(pTLSThreadCache);
        pTLSThreadCache->Deallocate(ptr, size);
    }
}

ObjectPool.h

#pragma once
#include "Common.h"

#ifdef _WIN32
#include <windows.h>
#else
// Linux / macOS
#include <unistd.h>
#include <sys/mman.h>
#include <stdexcept>
#endif
//Fixed-size object pool
/*template<size_t N>
class ObjectPool
{

};*/
/*inline static void* SystemAlloc(size_t kpage)
{
#ifdef _WIN32
    void* ptr = VirtualAlloc(0, kpage<<13,MEM_COMMIT | MEM_RESERVE,PAGE_READWRITE);
#else
// Linux / macOS: mmap (brk would also work)
    void *ptr = mmap(0, kpage << 13, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
#endif
    if (ptr == nullptr)
        throw std::bad_alloc();
    return ptr;
}
*/



template<class T>
class ObjectPool
{
public:
    T* New()
    {
        T* obj = nullptr;

        //Prefer reusing objects that have been returned to the pool
        if (_freeList)
        {
            void* next = *((void**)_freeList);
            obj = (T*)_freeList;
            _freeList = next;
        }
        else
        {
            //Each slot must be large enough to hold a free-list pointer
            size_t objSize = sizeof(T) < sizeof(void*) ? sizeof(void*) : sizeof(T);
            if (_remainBytes < objSize)//not enough left for one object: grab a new big block
            {
                _remainBytes = 128 * 1024;
                //_memory = (char*)malloc(128 * 1024);
                _memory = (char*)SystemAlloc(_remainBytes >> PAGE_SHIFT);
                if (_memory == nullptr)
                {
                    throw std::bad_alloc();
                }
            }

            obj = (T*)_memory;
            _memory += objSize;
            _remainBytes -= objSize;
        }

        //Placement new: explicitly run T's constructor (also for recycled objects)
        new(obj) T;
        return obj;
    }

    void Delete(T* obj)
    {
        obj->~T();
        //Push front: store the next pointer in the first sizeof(void*) bytes of the object.
        //Casting through void** works on both 32-bit (4-byte) and 64-bit (8-byte) platforms.
        *(void**)obj = _freeList;
        _freeList = obj;
    }


private:
    char* _memory = nullptr;//pointer into the current big block
    size_t _remainBytes = 0;//bytes of the big block not yet sliced
    void* _freeList = nullptr;//head of the free list of returned objects
};
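
A brief usage sketch of the fixed-size pool (the Foo type and function name are illustrative, not part of the project). Objects are carved out of a SystemAlloc'd 128KB block and recycled through the internal free list, so the pool never calls malloc or free.

struct Foo { int a = 0; int b = 0; };

void ObjectPoolUsageSketch()
{
    ObjectPool<Foo> pool;
    Foo* f1 = pool.New();   // carved from the big block; constructor runs via placement new
    Foo* f2 = pool.New();
    pool.Delete(f1);        // destructor runs; the slot is pushed onto the free list
    Foo* f3 = pool.New();   // reuses f1's slot
    (void)f2; (void)f3;
}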

PageCache.cpp

#include "PageCache.h"

PageCache PageCache::_sInst;

Span* PageCache::NewSpan(size_t k)
{
    assert(k > 0);
    //Spans of more than 128 pages are requested directly from the OS
    if (k > NPAGES - 1)
    {
        void* ptr = SystemAlloc(k);
        //Span* span = new Span;
        Span* span = _spanPool.New();
        span->_pageId = (PAGE_ID)ptr >> PAGE_SHIFT;
        span->_n = k;

       // _idSpanMap[span->_pageId] = span;
        _idSpanMap.set(span->_pageId, span);
        return span;
    }

    //First check whether the k-th bucket already holds a span
    if (!_spanLists[k].Empty())
    {
        Span* kSpan =  _spanLists[k].PopFront();
        //Map every page id to the span, so the central cache can find the span when small blocks are returned
        for (PAGE_ID i = 0; i < kSpan->_n; ++i)
        {
            _idSpanMap.set(kSpan->_pageId + i,  kSpan);
        }
        return kSpan;
    }

    //Otherwise check the larger buckets; if one of them has a span, split it
    for (size_t i = k + 1; i < NPAGES; ++i)
    {
        if (!_spanLists[i].Empty())
        {
            Span* nSpan = _spanLists[i].PopFront();
            //Span* kSpan = new Span;
            Span* kSpan = _spanPool.New();
            //Cut k pages off the head of nSpan:
            //the k-page span is returned,
            //and the remaining nSpan is hung back onto the bucket it now maps to
            kSpan->_pageId = nSpan->_pageId;
            kSpan->_n = k;

            nSpan->_pageId += k;
            nSpan->_n -= k;

            _spanLists[nSpan->_n].PushFront(nSpan);
            //Map nSpan's first and last page ids to nSpan, so the page cache can find it when merging during reclamation
            _idSpanMap.set(nSpan->_pageId, nSpan);
            _idSpanMap.set(nSpan->_pageId + nSpan->_n - 1, nSpan);

            //Map every page id of kSpan to kSpan, so the central cache can find the span when small blocks are returned
            for (PAGE_ID i = 0; i < kSpan->_n; ++i)
            {
                _idSpanMap.set(kSpan->_pageId + i, kSpan);
            }

            return kSpan;
        }
    }

    //Reaching here means none of the larger buckets has a span either;
    //ask the heap for a 128-page span
    Span* bigSpan = _spanPool.New();
    void* ptr = SystemAlloc(NPAGES - 1);
    bigSpan->_pageId = (PAGE_ID)ptr >> PAGE_SHIFT;//compute the page id
    bigSpan->_n = NPAGES - 1;

    _spanLists[bigSpan->_n].PushFront(bigSpan);

    return NewSpan(k);//recurse to reuse the splitting logic above
}
Span* PageCache::MapObjectToSpan(void* obj)
{
    PAGE_ID id = ((PAGE_ID)obj >> PAGE_SHIFT);

    /*std::unique_lock<std::mutex> lock(_pageMtx);//unlocks automatically when the function scope ends
    auto ret = _idSpanMap.find(id);
    if (ret != _idSpanMap.end())
    {
        return ret->second;
    }
    else {
        assert(false);//should never happen
        return nullptr;
    }
    */
    auto ret =(Span*)_idSpanMap.get(id);
    assert(ret != nullptr);
    return ret;
}

void PageCache::ReleaseSpanToPageCache(Span* span)
{
    //Spans of more than 128 pages are returned directly to the OS
    if (span->_n > NPAGES - 1)
    {
        void* ptr = (void*)(span->_pageId << PAGE_SHIFT);
        SystemFree(ptr, span->_n << PAGE_SHIFT);//size in bytes: pages * 8KB
        //delete span;
        _spanPool.Delete(span);

        return;
    }



    // Try to merge the span with the pages in front of and behind it, to ease fragmentation
    while (1) // merge forward (with the preceding pages)
    {
        PAGE_ID prevId = span->_pageId - 1;
        /*auto ret = _idSpanMap.find(prevId);

        //the previous page id is not mapped: stop merging
        if (ret == _idSpanMap.end())
        {
            break;
        }*/
        auto ret = (Span*)_idSpanMap.get(prevId);
        if (ret == nullptr)
        {
            break;
        }

        //the span of the preceding adjacent pages is in use: stop merging
        Span* prevSpan = ret;
        if (prevSpan->_isUsed == true)
        {
            break;
        }
        //merging would produce a span of more than 128 pages, which cannot be managed: stop merging
        if (prevSpan->_n + span->_n > NPAGES - 1)
        {
            break;
        }

        span->_pageId = prevSpan->_pageId;
        span->_n += prevSpan->_n;

        _spanLists[prevSpan->_n].Erase(prevSpan);
        //delete prevSpan;
        _spanPool.Delete(prevSpan);
    }

    //Merge backward (with the following pages)
    while (1)
    {
        PAGE_ID nextId = span->_pageId + span->_n;//first page id of the following block
        /*auto ret = _idSpanMap.find(nextId);
        if (ret == _idSpanMap.end())
        {
            break;
        }*/
        auto ret = (Span*)_idSpanMap.get(nextId);
        if (ret == nullptr)
        {
            break;
        }

        Span* nextSpan = ret;
        if (nextSpan->_isUsed == true)
        {
            break;
        }

        if (nextSpan->_n + span->_n > NPAGES - 1)
        {
            break;
        }

        span->_n += nextSpan->_n;

        _spanLists[nextSpan->_n].Erase(nextSpan);
        //delete nextSpan;
        _spanPool.Delete(nextSpan);

    }

    _spanLists[span->_n].PushFront(span);
    span->_isUsed = false;

    //_idSpanMap[span->_pageId] = span;
    //_idSpanMap[span->_pageId + span->_n - 1] = span;


    _idSpanMap.set(span->_pageId, span);
    _idSpanMap.set(span->_pageId + span->_n - 1, span);
}

PageCache.h

#pragma once

#include "Common.h"
#include "ObjectPool.h"
#include "PageMap.h"

class PageCache
{
public:
    static PageCache* GetInstance()
    {
        return &_sInst;
    }
    //Map an object to the span it belongs to
    Span* MapObjectToSpan(void* obj);

    //Return a free span to the page cache and merge adjacent spans
    void ReleaseSpanToPageCache(Span* span);
    Span* NewSpan(size_t K);
    std::mutex _pageMtx;
private:
    ObjectPool<Span> _spanPool;
    SpanList _spanLists[NPAGES];
    //std::unordered_map<PAGE_ID, Span*> _idSpanMap;//for integer-to-pointer mappings a radix tree is better (above all, reads need no lock)
    //std::map<PAGE_ID, Span*> _idSpanMap;
    TCMalloc_PageMap1<32 - PAGE_SHIFT> _idSpanMap;
    PageCache() {} // private constructor
    PageCache& operator=(PageCache const&) = delete;
    PageCache(const PageCache&) = delete;
    static PageCache _sInst;
};

PageMap.h

#pragma once
#include "Common.h"
#include "ObjectPool.h"
#ifndef ASSERT
#define ASSERT assert
#endif
// Single-level array
template <int BITS> // 32-bit: 32 - PAGE_SHIFT bits are enough to store a page id (2^32 / 2^13 = 2^19)
class TCMalloc_PageMap1 {
private:
    static const int LENGTH = 1 << BITS;
    void** array_;// 32-bit: 2^19 entries * 4 bytes = 2MB
public:
    typedef uintptr_t Number;
    //explicit TCMalloc_PageMap1(void* (*allocator)(size_t)) {
    explicit TCMalloc_PageMap1()
    {
       // array_ = reinterpret_cast<void**>((*allocator)(sizeof(void*) << BITS));
        size_t size = sizeof(void*) << BITS;
        size_t alignSize = SizeClass::_RoundUp(size,1<<PAGE_SHIFT);
        array_ = (void**)SystemAlloc(alignSize>>PAGE_SHIFT);

        memset(array_, 0, sizeof(void*) << BITS);
    }
    // Return the current value for KEY.  Returns NULL if not yet set,
    // or if k is out of range.
    void* get(Number k) const {
        if ((k >> BITS) > 0) {
            return NULL;
        }
        return array_[k];
    }
    // REQUIRES "k" is in range "[0,2^BITS-1]".
    // REQUIRES "k" has been ensured before.
    //
    // Sets the value 'v' for key 'k'.
    void set(Number k, void* v) {
        array_[k] = v;
    }
};
// Two-level radix tree
template <int BITS>
class TCMalloc_PageMap2 {
private:
    // Put 32 entries in the root and (2^BITS)/32 entries in each leaf.
    static const int ROOT_BITS = 5;
    static const int ROOT_LENGTH = 1 << ROOT_BITS;
    static const int LEAF_BITS = BITS - ROOT_BITS;
    static const int LEAF_LENGTH = 1 << LEAF_BITS;
    // Leaf node
    struct Leaf {
        void* values[LEAF_LENGTH];
    };
    Leaf* root_[ROOT_LENGTH];
    void* (*allocator_)(size_t);
public:
    typedef uintptr_t Number;
    // Pointers to 32 child nodes
    // Memory allocator
    //explicit TCMalloc_PageMap2(void* (*allocator)(size_t)) {
    explicit TCMalloc_PageMap2()
    {
        //allocator_ = allocator;
        memset(root_, 0, sizeof(root_));
        PreallocateMoreMemory();
    }
    void* get(Number k) const {
        const Number i1 = k >> LEAF_BITS;
        const Number i2 = k & (LEAF_LENGTH - 1);
        if ((k >> BITS) > 0 || root_[i1] == NULL) {
            return NULL;
        }
        return root_[i1]->values[i2];
    }
    void set(Number k, void* v) {
        const Number i1 = k >> LEAF_BITS;
        const Number i2 = k & (LEAF_LENGTH - 1);
        ASSERT(i1 < ROOT_LENGTH);
        root_[i1]->values[i2] = v;
    }
    bool Ensure(Number start, size_t n) {
        for (Number key = start; key <= start + n - 1;) {
            const Number i1 = key >> LEAF_BITS;
            // Check for overflow
            if (i1 >= ROOT_LENGTH)
                return false;
            // Make 2nd level node if necessary
            if (root_[i1] == NULL) {
                Leaf* leaf = reinterpret_cast<Leaf*>((*allocator_)
                    (sizeof(Leaf)));
                if (leaf == NULL) return false;
                memset(leaf, 0, sizeof(*leaf));
                root_[i1] = leaf;
            }
            // Advance key past whatever is covered by this leaf node
            key = ((key >> LEAF_BITS) + 1) << LEAF_BITS;
        }
        return true;
    }
    void PreallocateMoreMemory() {
        // Allocate enough to keep track of all possible pages
        Ensure(0, 1 << BITS);
    }
};


// Three-level radix tree
template <int BITS>
class TCMalloc_PageMap3 {
private:
    // How many bits should we consume at each interior level
    static const int INTERIOR_BITS = (BITS + 2) / 3; // Round-up
    static const int INTERIOR_LENGTH = 1 << INTERIOR_BITS;
    // How many bits should we consume at leaf level
    static const int LEAF_BITS = BITS - 2 * INTERIOR_BITS;
    static const int LEAF_LENGTH = 1 << LEAF_BITS;
    // Interior node
    struct Node {
        Node* ptrs[INTERIOR_LENGTH];
    };
    // Leaf node
    struct Leaf {
        void* values[LEAF_LENGTH];
    };
    Node* root_;
    void* (*allocator_)(size_t);
    // Root of radix tree
    // Memory allocator
    Node* NewNode() {
        Node* result = reinterpret_cast<Node*>((*allocator_)(sizeof(Node)));
        if (result != NULL) {
            memset(result, 0, sizeof(*result));
        }
        return result;
    }
public:
    typedef uintptr_t Number;
    explicit TCMalloc_PageMap3(void* (*allocator)(size_t)) {
        allocator_ = allocator;
        root_ = NewNode();
    }
    void* get(Number k) const {
        const Number i1 = k >> (LEAF_BITS + INTERIOR_BITS);
        const Number i2 = (k >> LEAF_BITS) & (INTERIOR_LENGTH - 1);
        const Number i3 = k & (LEAF_LENGTH - 1);
        if ((k >> BITS) > 0 ||
            root_->ptrs[i1] == NULL || root_->ptrs[i1]->ptrs[i2] == NULL) {
            return NULL;
        }
        return reinterpret_cast<Leaf*>(root_->ptrs[i1]->ptrs[i2])->values[i3];
    }
    void set(Number k, void* v) {
        ASSERT(k >> BITS == 0);
        const Number i1 = k >> (LEAF_BITS + INTERIOR_BITS);
        const Number i2 = (k >> LEAF_BITS) & (INTERIOR_LENGTH - 1);
        const Number i3 = k & (LEAF_LENGTH - 1);
      

         
        reinterpret_cast<Leaf*>(root_->ptrs[i1]->ptrs[i2])->values[i3] = v;
    }
    bool Ensure(Number start, size_t n) {
        for (Number key = start; key <= start + n - 1;) {
            const Number i1 = key >> (LEAF_BITS + INTERIOR_BITS);
            const Number i2 = (key >> LEAF_BITS) & (INTERIOR_LENGTH - 1);
            // Check for overflow
            if (i1 >= INTERIOR_LENGTH || i2 >= INTERIOR_LENGTH)
                return false;
            // Make 2nd level node if necessary
            if (root_->ptrs[i1] == NULL) {
                Node* n = NewNode();
                if (n == NULL) return false;
                root_->ptrs[i1] = n;
            }
            // Make leaf node if necessary
            if (root_->ptrs[i1]->ptrs[i2] == NULL) {
               // Leaf* leaf = reinterpret_cast<Leaf*>((*allocator_) (sizeof(Leaf)));
              //  if (leaf == NULL) return false;
                static ObjectPool<Leaf> LeafPool;
                Leaf* leaf = (Leaf*)LeafPool.New();
                memset(leaf, 0, sizeof(*leaf));
                root_->ptrs[i1]->ptrs[i2] = reinterpret_cast<Node*>(leaf);
            }
            // Advance key past whatever is covered by this leaf node
            key = ((key >> LEAF_BITS) + 1) << LEAF_BITS;
        }
        return true;
    }
    void PreallocateMoreMemory() {
    }
};
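
For reference, a minimal sketch of how the single-level map is used by PageCache::MapObjectToSpan (the names inside the function are illustrative): every page of a span is registered under its page id, and the page id of any object is just its address shifted right by PAGE_SHIFT, so a lookup is a plain array read that needs no lock.

// Sketch: assumes a 32-bit address space, matching TCMalloc_PageMap1<32 - PAGE_SHIFT>
void PageMapLookupSketch()
{
    static TCMalloc_PageMap1<32 - PAGE_SHIFT> idSpanMap;

    Span span;
    span._pageId = 1000;  // a 2-page span starting at page 1000
    span._n = 2;
    for (PAGE_ID i = 0; i < span._n; ++i)
        idSpanMap.set(span._pageId + i, &span);              // register every page of the span

    void* obj = (void*)((span._pageId + 1) << PAGE_SHIFT);   // an object inside the second page
    Span* found = (Span*)idSpanMap.get((PAGE_ID)obj >> PAGE_SHIFT);
    assert(found == &span);                                  // lookup without any locking
}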

ThreadCache.cpp

#include "ThreadCache.h"
#include "CentralCache.h"

void* ThreadCache::FetchFromCentralCache(size_t index, size_t size)
{
    //Slow-start feedback algorithm
    //1. At first we do not ask the central cache for too large a batch, since it might not all be used
    //2. If requests for this size keep coming, batchNum keeps growing until it reaches the upper bound SizeClass::NumMoveSize(size)
    //3. The larger the size, the smaller the batchNum requested from the central cache
    //4. The smaller the size, the larger the batchNum requested from the central cache
    //size_t batchNum = std::min(_freeLists[index].MaxSize(),SizeClass::NumMoveSize(size));
    size_t batchNum = _freeLists[index].MaxSize() <= SizeClass::NumMoveSize(size) ? _freeLists[index].MaxSize() : SizeClass::NumMoveSize(size);

    if (_freeLists[index].MaxSize() == batchNum)
    {
        _freeLists[index].MaxSize() += 1;
    }

    void* start = nullptr;
    void* end = nullptr;
    size_t actualNum = CentralCache::GetInstance()->FetchRangeObj(start, end, batchNum, size);
    assert(actualNum > 0);//at least one object should always come back (the author left a question mark here)

    if (actualNum == 1)
    {
        assert(start == end);
    }
    else
    {
        _freeLists[index].PushRange(NextObj(start), end, actualNum - 1);
    }
    return start;
}
void* ThreadCache::Allocate(size_t size)
{
    assert(size <= MAX_BYTES);
    size_t alignSize = SizeClass::RoundUp(size);
    size_t index = SizeClass::Index(size);

    if (!_freeLists[index].Empty())
    {
        return _freeLists[index].Pop();
    }
    else {
        return FetchFromCentralCache(index, alignSize);
    }
}

void ThreadCache::Deallocate(void* ptr, size_t size)
{
    assert(ptr);
    assert(size <= MAX_BYTES);

    //Find the free-list bucket this size maps to and push the object into it
    size_t index = SizeClass::Index(size);
    _freeLists[index].Push(ptr);
    //tcmalloc is more elaborate here (it considers both the list length and the total cached bytes)

    //When the list length reaches one batch's worth, return a segment of the list to the central cache
    if (_freeLists[index].Size() >= _freeLists[index].MaxSize())
    {
        ListTooLong(_freeLists[index], size);
    }

}
void ThreadCache::ListTooLong(FreeList& list, size_t size)
{
    void* start = nullptr;
    void* end = nullptr;
    list.PopRange(start, end, list.MaxSize());

    CentralCache::GetInstance()->ReleaseListToSpans(start, size);
}
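
A worked example of the slow-start batching above, using the constants from Common.h: for 8-byte objects, NumMoveSize(8) = 256*1024/8 = 32768, which is clamped to the upper bound of 512, so successive fetches for that bucket request 1, 2, 3, ... objects (MaxSize grows by one per fetch) until they level off at 512. For a 256KB object, NumMoveSize gives 1, which is clamped up to 2, so at most two such objects move per batch.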

ThreadCache.h

#pragma once
#include "Common.h"
class ThreadCache
{
public:
    //Allocate and free memory objects
    void* Allocate(size_t size);

    void Deallocate(void* ptr, size_t size);

    void* FetchFromCentralCache(size_t index, size_t size);

    //When a free list grows too long, return memory to the central cache
    void ListTooLong(FreeList& list, size_t size);
private:
    FreeList _freeLists[NFREELISTS];
};

#ifdef _WIN32
static __declspec(thread) ThreadCache* pTLSThreadCache = nullptr;
#else
// Linux / macOS
// TLS: thread-local storage
static __thread ThreadCache* pTLSThreadCache = nullptr;
#endif

UnitTest.cpp (tests)

#include "ObjectPool.h"
#include "ConcurrentAlloc.h"

void Alloc1()
{
    for (size_t i = 0; i < 5; ++i)
    {
        void* ptr = ConcurrentAlloc(6);
    }
}

void Alloc2()
{
    for (size_t i = 0; i < 5; ++i)
    {
        void* ptr = ConcurrentAlloc(7);
    }
}


void TLSTest()
{
    std::thread t1(Alloc1);
    t1.join();

    std::thread t2(Alloc2);
    t2.join();
}
void TestConcurrentAlloc1()
{
    void* p1 = ConcurrentAlloc(6);
    void* p2 = ConcurrentAlloc(8);
    void* p3 = ConcurrentAlloc(1);
    void* p4 = ConcurrentAlloc(7);
    void* p5 = ConcurrentAlloc(8);
    void* p6 = ConcurrentAlloc(8);
    void* p7 = ConcurrentAlloc(8);



    cout << p1 << endl;
    cout << p2 << endl;
    cout << p3 << endl;
    cout << p4 << endl;
    cout << p5 << endl;

    ConcurrentFree(p1);
    ConcurrentFree(p2);
    ConcurrentFree(p3);
    ConcurrentFree(p4);
    ConcurrentFree(p5);
    ConcurrentFree(p6);
    ConcurrentFree(p7);


}
void TestConcurrentAlloc2()
{
    for (size_t i = 0; i < 1024; ++i)
    {
        void* p1 = ConcurrentAlloc(6);
        cout << p1 << endl;
    }
    void* p2 = ConcurrentAlloc(6);
    cout << p2 << endl;
}

void TestAddressShift()
{
    PAGE_ID id1 = 2000;
    PAGE_ID id2 = 2001;
    char* p1 = (char*)(id1 << PAGE_SHIFT);
    char* p2 = (char*)(id2 << PAGE_SHIFT);
    /*cout<< p1 <<endl;
    cout<< p2 <<endl;*///wrong: streaming a char* prints it as a C string (and these addresses are not strings), so cast with static_cast<void*> instead
    cout << "p1 address: " << static_cast<void*>(p1) << endl;
    cout << "p2 address: " << static_cast<void*>(p2) << endl;
    while (p1 < p2)
    {
        cout << static_cast<void*>(p1) << ":" << ((PAGE_ID)p1 >> PAGE_SHIFT) << endl;
        p1 += 8;
    }

}


void MultiThreadAlloc1()
{
    std::vector<void*> v;
    for (size_t i = 0; i < 5; ++i)
    {
        void* ptr = ConcurrentAlloc(6);
        v.push_back(ptr);
    }

    for (auto e : v)
    {
        ConcurrentFree(e);
    }
}

void MultiThreadAlloc2()
{
    std::vector<void*> v;
    for (size_t i = 0; i < 5; ++i)
    {
        void* ptr = ConcurrentAlloc(6);
        v.push_back(ptr);
    }

    for (auto e : v)
    {
        ConcurrentFree(e);
    }
}
void TestMultiThread()
{
    std::thread t1(MultiThreadAlloc1);
    //t1.join(); joining here, before t2 starts, would run the threads one after another instead of concurrently
    std::thread t2(MultiThreadAlloc2);



    t1.join();
    t2.join();
}

void TestMultiThreadExtra()//this test exposed a bug
{
    std::thread t1(TestConcurrentAlloc1);
    t1.join();
    std::thread t2(TestConcurrentAlloc1);
    t2.join();
}

void BigAlloc()
{
    void* p1 = ConcurrentAlloc(257 * 1024);
    ConcurrentFree(p1);

    void* p2 = ConcurrentAlloc(129 * 1024);
    ConcurrentFree(p2);
}

/*int main()
{
    //TLSTest();
    //TestConcurrentAlloc2();
    //TestAddressShift();
    //TestMultiThread();
    //TestConcurrentAlloc1();

    //TestMultiThreadExtra();
    TestConcurrentAlloc1();

    //BigAlloc();
    return 0;
}*/

CMakeLists.txt

cmake_minimum_required (VERSION 3.1)

project (ConcurrentMemoryPool)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED True)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")

add_executable(a.out UnitTest.cpp ThreadCache.cpp CentralCache.cpp PageCache.cpp BenchMark.cpp)

set(CMAKE_VERBOSE_MAKEFILE ON)

  • The author has tested this: on 64-bit Linux it may report errors, possibly because the page-id mapping there differs from 32-bit (the single-level page map used here only covers a 32-bit address space).
  • CMakeLists.txt exists only to build the test files (development later moved to a Windows environment for unrelated reasons, so cmake was not used or tested after that).
  • Adding BenchMark.cpp may trigger a "deleted copy constructor" error: std::atomic<size_t> cannot be copy-initialized with = before C++17. Construct it with () or {} instead (see the sketch after this list), or switch to a newer compiler/standard version.
  • Besides tcmalloc, ptmalloc and jemalloc are also excellent memory allocators.
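
Regarding the std::atomic note above, a minimal sketch of the portable initialization forms: direct- or brace-initialization compiles under both older and newer standards, while copy-initialization with = is only accepted from C++17 onward.

#include <atomic>

std::atomic<size_t> malloc_costtime{0};   // OK under C++11 and later
std::atomic<size_t> free_costtime(0);     // OK under C++11 and later
// std::atomic<size_t> bad = 0;           // copy-initialization: rejected before C++17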