[Python算法优化] 递归算法记忆化技术、lru缓存技术及lru_cache源码分析

引言

在递归算法中,递归函数将迭代调用规模更小的递归函数。在每次调用过程中都消耗计算资源。以斐波那契数列为例:

def fib(n):
    if n == 1 or n == 2:
        return 1
    return fib(n - 1) + fib(n - 2)

在上述递归实现的代码中,以求斐波那契数列的第五项为例,有:

  fib(5)
= fib(4) + fib(3)
= fib(3) + fib(2) + fib(2) + fib(1)
= fib(2) + fib(1) + fib(2) + fib(2) + fib(1)
= 1 + 1 + 1 + 1 + 1
= 5

即第i项的值需要第i-1和第i-2项的值可以转化为:求第i项的值时需要用到所有第i项之前项数的值。

翻译成人话:求fib(5)需要fib(4)和fib(3),而fib(4)和fib(3)又需要fib(3),fib(2)和fib(1)。

在此过程中存在诸多重复计算,可以通过记忆化技术或缓存来实现

记忆化技术

思路:定义一个变量,在计算每一步fib()值时,将计算结果存入该变量,在后续计算中如果需要该值,则直接从变量中获取。

需要注意的是,比较简单的逻辑是先判断fib(x)是否存储在该变量中,即需要用到search方法,对于python的数据类型而言,字典是基于哈希散列实现的,其Search方法复杂度是O(1),而列表是O(n)。

cache = {} 


def fib(n):
    #在字典中查找
    if n in cache:
        return cache[n]
    if n == 1 or n == 2:
        value = 1
    elif n > 2:
        value = fib(n - 1) + fib(n - 2)
    #存入字典中
    cache[n] = value
    return value

使用记忆化技术可以大幅降低递归需要的时间,但是也会额外占用少量内存。

LRU缓存技术

Python内置的装饰器中提供了@lru_cache(maxsize, typed)。LRU为Least Recently Used,即最近最少使用。是一种基于Python的闭包且线程安全的实现。

LRU的机制为使用双向循环链表进行缓存,在缓存空间未满时直接增加节点,空间已满时删除最早的节点并添加当前的节点,在缓存命中时调整环形链表,将命中的节点移动到Root节点的右侧。

在使用@lru_cache时,有两个参数。maxsize表示最大缓存数量,即缓存该函数的多少结果值,默认128,设置为None时缓存将无限增长。typed表示是否按照数据类型做区分,如果typed=True,fib(10)和fib(10.0)将生成两个缓存值。

from functools import lru_cache


@lru_cache(maxsize=100, typed=False)
def fib(n):
    if n == 1 or n == 2:
        return 1
    elif n > 2:
        return fib(n - 1) + fib(n - 2)

此外,可以通过func.cache_info()方法查看缓存信息,通过func.cache_clear()方法清除缓存信息。

如果你的需求只是会使用,可以跳过下一部分。

源码

具体实现为,在加锁的前提下,对环形双链表进行添加缓存的操作。我在下文中的主要逻辑操作部分添加了注释。

需要注意的是,在此maxsize被分为0,None和else三种情况,对应三种情况分别写了3个wrapper函数。

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    sentinel = object()  # 未命中缓存时的默认值
    make_key = _make_key  # 定义缓存用的键值
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 #定义常用索引

    cache = {}  # 闭包缓存内容存储位置
    hits = misses = 0 
    full = False  # 判断缓存空间是否已满
    cache_get = cache.get  
    cache_len = cache.__len__  
    lock = RLock()  # 保证线程安全
    root = []  # 链表
    root[:] = [root, root, None, None]  # 建立环形双链表

    if maxsize == 0:
        def wrapper(*args, **kwds):
            # 当前没有缓存的操作
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:
        def wrapper(*args, **kwds):
            # 设置无限缓存后的操作
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:
        def wrapper(*args, **kwds):
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # 当链表不为空
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # 命中缓存
                    pass
                elif full:
                    # 缓存内容满的情况
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    root[KEY] = root[RESULT] = None
                    del cache[oldkey]
                    cache[key] = oldroot
                else:
                    # 添加新的缓存
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

作者对Python缓存机制的了解也不是很深入,如若文中有错误之处还请指正。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值