估计很少有人像我这样经常用 yield
语句,所以似乎在我之前没有人发现这个问题
引入
我们知道,yield
语句可以方便地将一个函数当做一个可迭代对象使用
def f():
yield 1
yield 4
for i in f():
print(i)
# 1
# 4
所以我们可以使用 yield
语句实现比较复杂的迭代(下面的代码不要求看懂)。codewars题目链接
def permutations(s):
def f(s):
if len(s) == 0:
yield ""
return
if len(s) == 1:
yield s
return
for ss in f(s[1:]):
yield s[0]+ss
for ss in f(s[:-1]):
yield s[-1]+ss
for i in range(1, len(s)-1):
for ss in f(s[:i]+s[i+1:]):
yield s[i]+ss
return
# Code Away!
return list(set(f(s)))
if __name__ == "__main__":
print(permutations("abcd")) # ['dabc', 'badc', 'dbac', 'dacb', 'bdac', 'cabd', 'abcd', 'cbda', 'acbd', 'dbca', 'bcda', 'cadb', 'cbad', 'acdb', 'bcad', 'adcb', 'adbc', 'cdab', 'bacd', 'dcab', 'abdc', 'bdca', 'cdba', 'dcba']
print(permutations("aabb")) #['aabb', 'bbaa', 'abba', 'baba', 'abab', 'baab']
问题
但我们不能将 yield
语句和 python 自带的 cache
放在一起用。下面是一个例子。
from functools import cache
@cache
def f():
yield 1
yield 2
print("this line will be processed for only once")
return
for a in f():
print(a)
for a in f():
print(a)
"""
1
2
this line will be processed for only once
"""
我们期望代码输出如下。
1
2
this line will be processed for only once
1
2
this line will be processed for only once
我们可以发现,对于加上缓存装饰器并且使用 yield
语句的函数来说,对于同一个输入,函数只会第一次执行的时候正常运行,而在第二次之后执行的时候不会正常执行。
我们用上述例子中的 f()
函数举例,该函数第一次返回后,该函数的返回值被记录在缓存中。在第二次及之后的运行中,f()
函数不会被调用,所以 yield
语句不会被执行,所以不会达到预期效果。
即使一个函数没有
return
语句,这样的异常也会出现,因为 Python 语法中会默认让没有返回值的函数返回None
。
所以有什么解决方法吗
我们可以自己实现一个 yield_cache
装饰器
import functools
def yield_cache(f):
@functools.cache
def a(*args, **kwargs):
return [elem for elem in f(*args, **kwargs)]
def decorating_function(*args, **kwargs):
for elem in a(*args, **kwargs):
yield elem
return decorating_function
# 此行以下是测试代码
if __name__ == "__main__":
called = 0
test1 = "abcd"
test2 = "aabb"
test3 = "aaabbb"
def permutations(s):
def f(s):
global called
called += 1
if len(s) == 0:
yield ""
return
if len(s) == 1:
yield s
return
for ss in f(s[1:]):
yield s[0]+ss
for ss in f(s[:-1]):
yield s[-1]+ss
for i in range(1, len(s)-1):
for ss in f(s[:i]+s[i+1:]):
yield s[i]+ss
return
# Code Away!
return list(set(f(s)))
print(permutations(test1))
print(called)
called = 0
print(permutations(test2))
print(called)
called = 0
print(permutations(test3))
print(called)
called = 0
def permutations(s):
@yield_cache
def f(s):
global called
called += 1
if len(s) == 0:
yield ""
return
if len(s) == 1:
yield s
return
for ss in f(s[1:]):
yield s[0]+ss
for ss in f(s[:-1]):
yield s[-1]+ss
for i in range(1, len(s)-1):
for ss in f(s[:i]+s[i+1:]):
yield s[i]+ss
return
# Code Away!
return list(set(f(s)))
print(permutations(test1))
print(called)
called = 0
print(permutations(test2))
print(called)
called = 0
print(permutations(test3))
print(called)
called = 0
不过说起来这个装饰器可能有些低效,因为函数 a()
内部是使用 list
类型来存储数据的。要是可以用 functools._lru_cache_wrapper()
函数就好了,可能会快一些。
下面是测试结果。其中每个数字表示函数 f()
被调用次数。
before using yield_cache_wrapper
['abdc', 'badc', 'cadb', 'cdba', 'dbca', 'bcad', 'cbda', 'dcba', 'cabd', 'bcda', 'acbd', 'bdac', 'cbad', 'dacb', 'dabc', 'cdab', 'abcd', 'bacd', 'acdb', 'bdca', 'dcab', 'dbac', 'adcb', 'adbc']
41
['aabb', 'abba', 'abab', 'baab', 'baba', 'bbaa']
41
['abbbaa', 'ababba', 'baaabb', 'babaab', 'abbaab', 'babbaa', 'bbbaaa', 'bbaaab', 'abaabb', 'bbabaa', 'abbaba', 'ababab', 'aaabbb', 'bbaaba', 'baabab', 'baabba', 'aabbab', 'bababa', 'aababb', 'aabbba']
1237
after using yield_cache_wrapper
['abdc', 'badc', 'cadb', 'cdba', 'dbca', 'bcad', 'cbda', 'dcba', 'cabd', 'bcda', 'acbd', 'bdac', 'cbad', 'dacb', 'dabc', 'cdab', 'abcd', 'bacd', 'acdb', 'bdca', 'dcab', 'dbac', 'adcb', 'adbc']
15
['aabb', 'abba', 'abab', 'baab', 'baba', 'bbaa']
8
['abbbaa', 'ababba', 'baaabb', 'babaab', 'abbaab', 'babbaa', 'bbbaaa', 'bbaaab', 'abaabb', 'bbabaa', 'abbaba', 'ababab', 'aaabbb', 'bbaaba', 'baabab', 'baabba', 'aabbab', 'bababa', 'aababb', 'aabbba']
15
为什么我们想要把 yield
语句和缓存一起使用
比如上面 codewars 代码例子。假设我们传入 f(s)
的字符串长度非常长,那么我们自然希望这个函数每次的迭代值能够缓存下来。
为什么想到使用缓存
假设我们想要使用递归的方法,计算斐波那契数列的第 900 项。那么我们会发现,随着项数的增加,计算斐波那契数列的时间呈现指数级增长。
def fib(n):
if n<=1:
return 1
else:
return fib(n-1)+fib(n-2)
我们可以试一下为什么。下列代码会帮我们计算 fib()
函数被调用次数。
a=0
def fib(n):
global a
a += 1
if n<=1:
return 1
else:
return fib(n-1)+fib(n-2)
if __name__ == "__main__":
a = 0
print(fib(10))
print(a)
a = 0
print(fib(20))
print(a)
a = 0
print(fib(30))
print(a)
我们可以看到结果。计算 fib(10)
调用了 177
次 fib()
函数,fib(20)
是21897
次,fib(30)
是 2692537
次。
但如果我们将 fib()
函数用 @functools.cache
修饰,那么次数会大幅减少。
a=0
import functools
@functools.cache
def fib(n):
global a
a += 1
if n<=1:
return 1
else:
return fib(n-1)+fib(n-2)
if __name__ == "__main__":
a = 0
print(fib(10))
print(a)
a = 0
print(fib(20))
print(a)
a = 0
print(fib(30))
print(a)
结果为
89
11
10946
10
1346269
10