文章目录
前言
我们经常遇见各种池的概念,数据库连接池、消息队列池等。我们今天一起来看下缓存池的运用。
一、什么是线程安全?
线程安全是指多个线程同时访问一个对象时,不会产生任何不正确的结果或不一致的状态。在Python中,可以通过使用锁或其他同步机制来实现线程安全。
1. 看langchain-chatchat里的示例
class ThreadSafeObject:
def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
self._obj = obj
self._key = key
self._pool = pool
self._lock = threading.RLock()
self._loaded = threading.Event()
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self.key}, obj: {self._obj}>"
@property
def key(self):
return self._key
@contextmanager
def acquire(self, owner: str = "", msg: str = "") -> FAISS:
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self.key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self.key}。{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self.key}。{msg}")
self._lock.release()
def start_loading(self):
self._loaded.clear()
def finish_loading(self):
self._loaded.set()
def wait_for_loading(self):
self._loaded.wait()
@property
def obj(self):
return self._obj
@obj.setter
def obj(self, val: Any):
self._obj = val
2. 解读
在给定的示例中,ThreadSafeObject类使用了一个RLock对象(可重入锁)来实现线程安全。
1) 在类的初始化方法中,创建了一个RLock对象self._lock,这个锁用于控制对共享资源的访问。
2)在acquire方法中,使用self._lock.acquire()获取锁,确保在任何时候只有一个线程可以访问共享资源。
3) 在使用完共享资源后,使用self._lock.release()释放锁,允许其他线程访问共享资源。
4) 通过使用上下文管理器(@contextmanager),可以确保在使用共享资源时正确地获取和释放锁,即在yield之前获取锁,在yield之后释放锁。
5) 在其他方法中,如start_loading、finish_loading和wait_for_loading,通过使用self._loaded来同步多个线程的操作。
3. 作用和注意事项
通过使用锁和同步机制,ThreadSafeObject类确保了在多个线程同时访问时,共享资源不会被破坏或产生不一致的状态。这样就实现了线程安全。
需要注意的是,使用锁和同步机制会引入额外的开销,并可能导致性能下降。因此,在设计多线程应用程序时,需要权衡使用线程安全的成本和收益。
不明白@contextmanager可以看之前的文章,它是一个方便的注解,能自动生成with需要的两个方法
看langchain代码之前必备知识之 - with、yield和@contextmanager
二、缓存池
1. 缓存池代码示例(来自langchain-chatchat)
class CachePool:
def __init__(self, cache_num: int = -1):
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()
def keys(self) -> List[str]:
return list(self._cache.keys())
def _check_count(self):
if isinstance(self._cache_num, int) and self._cache_num > 0:
while len(self._cache) > self._cache_num:
self._cache.popitem(last=False)
def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache
def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
self._cache[key] = obj
self._check_count()
return obj
def pop(self, key: str = None) -> ThreadSafeObject:
if key is None:
return self._cache.popitem(last=False)
else:
return self._cache.pop(key, None)
def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
cache = self.get(key)
if cache is None:
raise RuntimeError(f"请求的资源 {key} 不存在")
elif isinstance(cache, ThreadSafeObject):
self._cache.move_to_end(key)
return cache.acquire(owner=owner, msg=msg)
else:
return cache
def load_kb_embeddings(
self,
kb_name: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> Embeddings:
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
kb_detail = get_kb_detail(kb_name)
embed_model = kb_detail.get("embed_model", default_embed_model)
if embed_model in list_online_embed_models():
return EmbeddingsFunAdapter(embed_model)
else:
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
2. 解读
这个缓存池类用于缓存一些对象,以便在需要时可以快速获取。下面是对该类的详细解释:
-
__init__(self, cache_num: int = -1)
:构造函数,初始化缓存池。cache_num
参数表示缓存的最大数量,默认为-1,表示不限制缓存数量。_cache_num
变量用于保存缓存的最大数量,_cache
变量用于保存缓存对象,atomic
变量用于实现线程安全。这里self._cache = OrderedDict() 是创建了一个有序字典,顺序为插入顺序。这点很重要,为后面的清除旧缓存打下了基础。 -
keys(self) -> List[str]
:返回当前缓存池中所有缓存对象的键的列表。 -
_check_count(self)
:检查缓存的数量是否超过了最大限制,并根据需要移除最旧的缓存对象。也就是上面所述的有序字典类型self._cache 最早插入的元素 -
get(self, key: str) -> ThreadSafeObject
:根据指定的键获取缓存对象。如果对象存在,则等待对象加载完成后返回。如果对象不存在,则返回None。 -
set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject
:将指定的缓存对象添加到缓存池中,并检查缓存数量是否超过最大限制。返回添加的缓存对象。 -
pop(self, key: str = None) -> ThreadSafeObject
:从缓存池中移除指定键的缓存对象。如果未指定键,则移除最旧的缓存对象。返回被移除的缓存对象。 -
acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = "")
:根据指定的键获取缓存对象,并返回一个线程安全的对象。如果缓存对象不存在,则抛出运行时错误。如果缓存对象为线程安全的对象,则移动该对象到缓存池的末尾,并调用其acquire方法并返回结果。否则,直接返回缓存对象。思考下为什么要移动到末尾呢?原因是在缓存容量达到最高的时候,优先清除self._cache 最早插入的元素,而移动到末尾,这个key就相当于最新插入的元素 -
load_kb_embeddings(self, kb_name: str, embed_device: str = embedding_device(), default_embed_model: str = EMBEDDING_MODEL) -> Embeddings
:根据知识库名称从数据库中获取知识库的详细信息。根据知识库的嵌入模型选择加载嵌入数据的方法。如果嵌入模型在在线嵌入模型列表中,则返回一个适配器对象,该对象可以根据需要调用在线嵌入模型的方法。否则,调用embeddings_pool.load_embeddings
方法加载离线嵌入数据,并返回加载的嵌入数据对象。
3. 有序字典的知识补充
1)定义
Python中的有序字典是一种数据结构,它是字典(无序的键值对集合)的一个子类。与普通字典不同的是,有序字典会记住元素添加的顺序,并且可以按照这个顺序进行遍历。
有序字典的一些特点包括:
- 元素的插入顺序会被记住,不论之后是否进行了删除或修改操作。
- 可以通过键来访问元素,与普通字典相同。
- 可以使用
keys()
、values()
和items()
等方法来获取有序字典中的所有键、值或键值对,并且会按照插入顺序返回。 - 可以使用
popitem()
方法来弹出有序字典中的最后一个插入的元素。 - 可以使用
move_to_end()
方法来将指定键的元素移动到有序字典的末尾。
2)场景
有序字典的使用场景包括:
- 需要记录数据的插入顺序并保持顺序的情况,如日志记录。
- 需要按照元素插入顺序进行遍历的情况,如构建LRU缓存。
在实际使用中,如果需要有序字典的功能,可以直接使用Python的内置字典类型,因为它已经支持有序操作。如果需要在Python 3.7之前的版本中使用有序字典,可以通过导入collections
模块来使用。
3)示例
看下面的代码
from collections import OrderedDict
if __name__ == "__main__":
item = OrderedDict()
item['a'] = 1
item['b'] = 2
item['c'] = 3
print(item)
item.move_to_end('a')
print(item)
item.popitem(last=False)
print(item)
运行结果如下
总结
上的示例,实现了一个缓存池,这个缓存池依靠有序字典的特性,维持字典里key的数量,并将最新访问过的key往末尾移动,保证了热门key不容易被清理,冷门key优先清理的处理逻辑。同时也因为缓存池会被多线程运用,所以需要线程安全
补充
实际上我在测试上面的线程安全的时候,一直无法模拟出线程不安全的效果。
这是由于python存在GIL(全局解释锁)
Python中,每个线程的执行方式:
1)获取GIL
2)执行代码直到sleep或python解释器将其挂起
3)释放GIL
这就导致了在Python进程中,即使有多个线程,同一时间也是仅有一个线程在执行。