When a function performs a time-consuming operation whose result depends only on its arguments, you can use this function as a decorator to improve response time. It is essentially a cache pool: after a new computation the result is cached, so the next call with the same arguments may not need to recompute. May, not will -- the code walkthrough below explains why.
Suppose a function takes a string, reads the corresponding file from disk, and returns its text to the caller; it can be used like this:
from functools import lru_cache

@lru_cache()
def readFile(filename: str):
    with open(filename) as f:
        return f.read()
However, lru_cache places restrictions on the decorated function's arguments, as we will see below.
The source code of lru_cache:
def lru_cache(maxsize=128, typed=False):
    if isinstance(maxsize, int):
        if maxsize < 0:
            maxsize = 0
    elif maxsize is not None:
        raise TypeError('Expected maxsize to be an integer or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function
First, this function accepts two parameters:
- maxsize: a non-negative integer or None, the maximum number of cached entries. 0 means nothing is cached at all, so the decorator loses its purpose; None means every new result is cached with no limit, until memory runs out;
- typed: a bool indicating whether argument types should be strictly distinguished. For example, 1 and 1.0 are mathematically equal but have different types in the computer. With typed=True, for a function f, the results of f(1) and f(1.0) are computed and cached separately; with typed=False, f(1) and f(1.0) are treated as the same call. Users can thus decide whether that distinction matters for their own logic;
It returns the decorating_function, which in turn calls _lru_cache_wrapper. Let's look at that function's implementation:
def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper
This function is quite long; it does several things:
- It decides how to compute a key from the function's arguments (make_key = _make_key): since results are cached, there must be a way to check whether a given set of arguments has already been computed;
- It creates a per-function lock (lock = RLock()) to keep the data safe under multithreading;
- It builds a circular doubly linked list (root = []). This exists because the user may cap the cache size, and lru is short for Least-recently-used cache decorator: when the cache is full, the least recently used key must be evicted. The linked list implements exactly this. The entry touched most recently is moved to the tail of the list (just before root), so entries that are rarely or never reused drift toward the head and are the first to be evicted;
- It attaches two extra function attributes to the decorated function: one reports the current state of the cache, the other clears all cached entries;
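Both attached attributes can be exercised directly; a quick sketch with a made-up square function:

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def square(x):
    return x * x

square(2)
square(2)   # served from the cache
square(3)
print(square.cache_info())   # CacheInfo(hits=1, misses=2, maxsize=2, currsize=2)

square.cache_clear()         # empty the pool and reset the statistics
print(square.cache_info())   # CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)
```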
Computing the key
The code stores the internal function _make_key in the make_key variable; here is its implementation:
def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str},
             tuple=tuple, type=type, len=len):
    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)
It, too, does several things:
- It builds a tuple: starting from the args tuple, it appends the key-value pairs of kwds (a dict). (This appears to rely on the insertion-order guarantee of modern dicts; note that the items are not sorted, so the same keyword arguments passed in a different order produce a different key. Python 2 has no lru_cache, and a backport would have to sort the items manually.)
- It checks whether types must be distinguished; if so, the type of each argument is appended to the tuple as well;
- If there is exactly one argument (and typed is False) and its type is str or int, the value itself is returned directly, since it can serve as the key on its own;
- Otherwise the tuple is wrapped in a _HashedSeq object, which computes the tuple's hash.
First, the implementation of _HashedSeq:
class _HashedSeq(list):
    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue
It computes the hash of the tuple passed in (though not stated explicitly, the parameter could be any hashable object; within lru_cache it is always the tuple made of the arguments and, optionally, their types), then stores that hash value so it never has to be recomputed.
Creating the lock
Not much to say here: lock = RLock() creates a reentrant lock, used when manipulating the cache to keep threads from corrupting each other's data;
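Why RLock rather than a plain Lock? The source comments warn about the "potentially reentrant cache[key] assignment": a cache update can trigger code that re-enters the wrapper on the same thread. A reentrant lock lets the same thread acquire it again without deadlocking; a tiny sketch:

```python
from threading import RLock

lock = RLock()
reacquired = False
with lock:
    with lock:  # the same thread may re-acquire an RLock without deadlocking
        reacquired = True
print(reacquired)  # True; a plain Lock would deadlock here
```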
Building the circular doubly linked list
First, the declaration:

root = []
root[:] = [root, root, None, None]

root is the starting node of the circular list; it stores no cached value itself, and the valid entries begin at the node after it. Each element of the list is a four-element list (per PREV, NEXT, KEY, RESULT = 0, 1, 2, 3):
- 0 (PREV), the previous node;
- 1 (NEXT), the next node;
- 2 (KEY), the key;
- 3 (RESULT), the cached value;
The main logic revolves around this circular list; depending on the configured cache size, a different wrapper is returned:
- Cache size 0 (maxsize == 0): caching is disabled, and wrapper is little more than a plain call that counts misses;
- Unlimited cache size (maxsize is None): the doubly linked list is not involved at all. wrapper first looks in cache to see whether the current arguments were computed before; if so, the value is fetched from cache and returned; if not, it is computed once, cached, and then returned;
- A bounded cache size: the tricky case, since besides the cache dict, the circular list must also be adjusted. If the arguments have been seen before, their list node is moved to the tail (most-recently-used position). On a first computation, besides computing, the code checks whether the list is full; if it is, the root pointer advances to the node after the old root (the oldest entry), and the old root node is reused to store the newly computed key and result, saving one memory allocation and one deallocation;
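The bounded-cache eviction order can be observed from outside; a sketch with maxsize=2 (the function f and the calls list are made up for illustration):

```python
from functools import lru_cache

calls = []

@lru_cache(maxsize=2)
def f(x):
    calls.append(x)   # record only the real computations
    return x

f(1); f(2)   # cache now holds 1 and 2
f(1)         # hit: 1 becomes the most recently used entry
f(3)         # cache full: the least recently used key, 2, is evicted
f(1)         # still cached
f(2)         # was evicted, so it is computed again
print(calls) # [1, 2, 3, 2]
```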
Attaching extra attributes to the function
- cache_info(), which returns a namedtuple describing the current state of the cache pool:

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

- cache_clear(), which empties the cache pool and resets the linked list;
Like partial, this pure-Python function may in fact never execute, because it is replaced by the C implementation when that is available:

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass

Not that it matters much.
Two things to watch out for when using lru_cache:
1 A hash is computed from the arguments passed in, which means every argument value must be hashable; if even one of them is not, the program raises an error:
from functools import lru_cache

@lru_cache(100)
def sum(x):
    sum = 0
    for e in x:
        sum += e
    return sum

print(sum((1, 2, 3)))  # succeeds: a tuple is hashable
print(sum([1, 2, 3]))  # TypeError: unhashable type: 'list'
2 The hash depends on the order in which keyword arguments are passed; see the example below:
from functools import lru_cache

@lru_cache(100)
def p(x, y):
    print(x, y)

p(x=10, y=20)
print(p.cache_info())  # misses=1, hits=0

p(y=20, x=10)
print(p.cache_info())  # misses=2, hits=0 -- the second call was treated as a new key
So users need to keep their argument order consistent to maintain the hit rate.