一起看langchain-chatchat代码 - python的线程安全和缓存池


前言

我们经常遇见各种池的概念,数据库连接池、消息队列池等。我们今天一起来看下缓存池的运用。


一、什么是线程安全?

线程安全是指多个线程同时访问一个对象时,不会产生任何不正确的结果或不一致的状态。在Python中,可以通过使用锁或其他同步机制来实现线程安全。

1. 看langchain-chatchat里的示例

class ThreadSafeObject:
    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> FAISS:
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                self._pool._cache.move_to_end(self.key)
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}{msg}")
            self._lock.release()

    def start_loading(self):
        self._loaded.clear()

    def finish_loading(self):
        self._loaded.set()

    def wait_for_loading(self):
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val

2. 解读

在给定的示例中,ThreadSafeObject类使用了一个RLock对象(可重入锁)来实现线程安全。

1) 在类的初始化方法中,创建了一个RLock对象self._lock,这个锁用于控制对共享资源的访问。
2)在acquire方法中,使用self._lock.acquire()获取锁,确保在任何时候只有一个线程可以访问共享资源。
3) 在使用完共享资源后,使用self._lock.release()释放锁,允许其他线程访问共享资源。
4) 通过使用上下文管理器(@contextmanager),可以确保在使用共享资源时正确地获取和释放锁,即在yield之前获取锁,在yield之后释放锁。
5) 在其他方法中,如start_loading、finish_loading和wait_for_loading,通过使用self._loaded来同步多个线程的操作。

3. 作用和注意事项

通过使用锁和同步机制,ThreadSafeObject类确保了在多个线程同时访问时,共享资源不会被破坏或产生不一致的状态。这样就实现了线程安全。

需要注意的是,使用锁和同步机制会引入额外的开销,并可能导致性能下降。因此,在设计多线程应用程序时,需要权衡使用线程安全的成本和收益。

不明白@contextmanager可以看之前的文章,它是一个方便的注解,能自动生成with需要的两个方法
看langchain代码之前必备知识之 - with、yield和@contextmanager

二、缓存池

1. 缓存池代码示例(来自langchain-chatchat)

class CachePool:
    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        return list(self._cache.keys())

    def _check_count(self):
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        if key is None:
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = embedding_device(),
            default_embed_model: str = EMBEDDING_MODEL,
    ) -> Embeddings:
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        kb_detail = get_kb_detail(kb_name)
        embed_model = kb_detail.get("embed_model", default_embed_model)

        if embed_model in list_online_embed_models():
            return EmbeddingsFunAdapter(embed_model)
        else:
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)

2. 解读

这个缓存池类用于缓存一些对象,以便在需要时可以快速获取。下面是对该类的详细解释:

  • __init__(self, cache_num: int = -1):构造函数,初始化缓存池。cache_num参数表示缓存的最大数量,默认为-1,表示不限制缓存数量。_cache_num变量用于保存缓存的最大数量,_cache变量用于保存缓存对象,atomic变量用于实现线程安全。这里self._cache = OrderedDict() 是创建了一个有序字典,顺序为插入顺序。这点很重要,为后面的清除旧缓存打下了基础。

  • keys(self) -> List[str]:返回当前缓存池中所有缓存对象的键的列表。

  • _check_count(self):检查缓存的数量是否超过了最大限制,并根据需要移除最旧的缓存对象。也就是上面所述的有序字典类型self._cache 最早插入的元素

  • get(self, key: str) -> ThreadSafeObject:根据指定的键获取缓存对象。如果对象存在,则等待对象加载完成后返回。如果对象不存在,则返回None。

  • set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:将指定的缓存对象添加到缓存池中,并检查缓存数量是否超过最大限制。返回添加的缓存对象。

  • pop(self, key: str = None) -> ThreadSafeObject:从缓存池中移除指定键的缓存对象。如果未指定键,则移除最旧的缓存对象。返回被移除的缓存对象。

  • acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):根据指定的键获取缓存对象,并返回一个线程安全的对象。如果缓存对象不存在,则抛出运行时错误。如果缓存对象为线程安全的对象,则移动该对象到缓存池的末尾,并调用其acquire方法并返回结果。否则,直接返回缓存对象。思考下为什么要移动到末尾呢?原因是在缓存容量达到最高的时候,优先清除self._cache 最早插入的元素,而移动到末尾,这个key就相当于最新插入的元素

  • load_kb_embeddings(self, kb_name: str, embed_device: str = embedding_device(), default_embed_model: str = EMBEDDING_MODEL) -> Embeddings:根据知识库名称从数据库中获取知识库的详细信息。根据知识库的嵌入模型选择加载嵌入数据的方法。如果嵌入模型在在线嵌入模型列表中,则返回一个适配器对象,该对象可以根据需要调用在线嵌入模型的方法。否则,调用embeddings_pool.load_embeddings方法加载离线嵌入数据,并返回加载的嵌入数据对象。

3. 有序字典的知识补充

1)定义

Python中的有序字典是一种数据结构,它是字典(无序的键值对集合)的一个子类。与普通字典不同的是,有序字典会记住元素添加的顺序,并且可以按照这个顺序进行遍历。

有序字典的一些特点包括:

  • 元素的插入顺序会被记住,不论之后是否进行了删除或修改操作。
  • 可以通过键来访问元素,与普通字典相同。
  • 可以使用keys()values()items()等方法来获取有序字典中的所有键、值或键值对,并且会按照插入顺序返回。
  • 可以使用popitem()方法来弹出有序字典中的最后一个插入的元素。
  • 可以使用move_to_end()方法来将指定键的元素移动到有序字典的末尾。

2)场景

有序字典的使用场景包括:

  • 需要记录数据的插入顺序并保持顺序的情况,如日志记录。
  • 需要按照元素插入顺序进行遍历的情况,如构建LRU缓存。

在实际使用中,如果需要有序字典的功能,可以直接使用Python的内置字典类型,因为它已经支持有序操作。如果需要在Python 3.7之前的版本中使用有序字典,可以通过导入collections模块来使用。

3)示例

看下面的代码

from collections import OrderedDict

if __name__ == "__main__":
    item = OrderedDict()
    item['a'] = 1
    item['b'] = 2
    item['c'] = 3
    print(item)

    item.move_to_end('a')

    print(item)
    item.popitem(last=False)
    print(item)
    

运行结果如下
运行结果


总结

上的示例,实现了一个缓存池,这个缓存池依靠有序字典的特性,维持字典里key的数量,并将最新访问过的key往末尾移动,保证了热门key不容易被清理,冷门key优先清理的处理逻辑。同时也因为缓存池会被多线程运用,所以需要线程安全

补充

实际上我在测试上面的线程安全的时候,一直无法模拟出线程不安全的效果。

这是由于python存在GIL(全局解释锁)

Python中,每个线程的执行方式:

1)获取GIL
2)执行代码直到sleep或python解释器将其挂起
3)释放GIL

这就导致了在Python进程中,即使有多个线程,同一时间也是仅有一个线程在执行。

  • 34
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值