【Python】yield和cache的坑

估计很少有人像我这样经常用 yield 语句,所以似乎在我之前没有人发现这个问题

引入

我们知道,yield 语句可以方便地将一个函数当做一个可迭代对象使用

def f():
	yield 1
	yield 4

for i in f():
	print(i)
	# 1
	# 4

所以我们可以使用 yield 语句实现比较复杂的迭代(下面的代码不要求看懂)。codewars题目链接

def permutations(s):
    def f(s):
        if len(s) == 0:
            yield ""
            return
        if len(s) == 1:
            yield s
            return
        for ss in f(s[1:]):
            yield s[0]+ss
        for ss in f(s[:-1]):
            yield s[-1]+ss
        for i in range(1, len(s)-1):
            for ss in f(s[:i]+s[i+1:]):
                yield s[i]+ss
        return
    
    # Code Away!
    return list(set(f(s)))

if __name__ == "__main__":
    print(permutations("abcd")) # ['dabc', 'badc', 'dbac', 'dacb', 'bdac', 'cabd', 'abcd', 'cbda', 'acbd', 'dbca', 'bcda', 'cadb', 'cbad', 'acdb', 'bcad', 'adcb', 'adbc', 'cdab', 'bacd', 'dcab', 'abdc', 'bdca', 'cdba', 'dcba']
    print(permutations("aabb")) #['aabb', 'bbaa', 'abba', 'baba', 'abab', 'baab']

问题

但我们不能将 yield 语句和 python 自带的 cache 放在一起用。下面是一个例子。

from functools import cache

@cache
def f():
    yield 1
    yield 2
    print("this line will be processed for only once")
    return 

for a in f():
    print(a)


for a in f():
    print(a)
"""
1
2
this line will be processed for only once
"""

我们期望代码输出如下。

1
2
this line will be processed for only once
1
2
this line will be processed for only once

我们可以发现,对于加上缓存装饰器并且使用 yield 语句的函数来说,对于同一个输入,函数只会第一次执行的时候正常运行,而在第二次之后执行的时候不会正常执行。

我们用上述例子中的 f() 函数举例,该函数第一次返回后,该函数的返回值被记录在缓存中。在第二次及之后的运行中,f() 函数不会被调用,所以 yield 语句不会被执行,所以不会达到预期效果。

即使一个函数没有 return 语句,这样的异常也会出现,因为 Python 语法中会默认让没有返回值的函数返回 None

所以有什么解决方法吗

我们可以自己实现一个 yield_cache 装饰器

import functools

def yield_cache(f):
    @functools.cache
    def a(*args, **kwargs):
        return [elem for elem in f(*args, **kwargs)]

    def decorating_function(*args, **kwargs):
        for elem in a(*args, **kwargs):
            yield elem

    return decorating_function

# 此行以下是测试代码
if __name__ == "__main__":
    called = 0
    test1 = "abcd"
    test2 = "aabb"
    test3 = "aaabbb"
    def permutations(s):
        def f(s):
            global called
            called += 1
            if len(s) == 0:
                yield ""
                return
            if len(s) == 1:
                yield s
                return
            for ss in f(s[1:]):
                yield s[0]+ss
            for ss in f(s[:-1]):
                yield s[-1]+ss
            for i in range(1, len(s)-1):
                for ss in f(s[:i]+s[i+1:]):
                    yield s[i]+ss
            return
        
        # Code Away!
        return list(set(f(s)))
    
    print(permutations(test1))
    print(called)
    called = 0
    print(permutations(test2))
    print(called)
    called = 0
    print(permutations(test3))
    print(called)
    called = 0

    def permutations(s):
        @yield_cache
        def f(s):
            global called
            called += 1
            if len(s) == 0:
                yield ""
                return
            if len(s) == 1:
                yield s
                return
            for ss in f(s[1:]):
                yield s[0]+ss
            for ss in f(s[:-1]):
                yield s[-1]+ss
            for i in range(1, len(s)-1):
                for ss in f(s[:i]+s[i+1:]):
                    yield s[i]+ss
            return
        
        # Code Away!
        return list(set(f(s)))
    
    print(permutations(test1))
    print(called)
    called = 0
    print(permutations(test2))
    print(called)
    called = 0
    print(permutations(test3))
    print(called)
    called = 0

不过说起来这个装饰器可能有些低效,因为函数 a() 内部是使用 list 类型来存储数据的。要是可以用 functools._lru_cache_wrapper() 函数就好了,可能会快一些。

下面是测试结果。其中每个数字表示函数 f() 被调用次数。

before using yield_cache_wrapper
['abdc', 'badc', 'cadb', 'cdba', 'dbca', 'bcad', 'cbda', 'dcba', 'cabd', 'bcda', 'acbd', 'bdac', 'cbad', 'dacb', 'dabc', 'cdab', 'abcd', 'bacd', 'acdb', 'bdca', 'dcab', 'dbac', 'adcb', 'adbc']
41
['aabb', 'abba', 'abab', 'baab', 'baba', 'bbaa']
41
['abbbaa', 'ababba', 'baaabb', 'babaab', 'abbaab', 'babbaa', 'bbbaaa', 'bbaaab', 'abaabb', 'bbabaa', 'abbaba', 'ababab', 'aaabbb', 'bbaaba', 'baabab', 'baabba', 'aabbab', 'bababa', 'aababb', 'aabbba']
1237
after using yield_cache_wrapper
['abdc', 'badc', 'cadb', 'cdba', 'dbca', 'bcad', 'cbda', 'dcba', 'cabd', 'bcda', 'acbd', 'bdac', 'cbad', 'dacb', 'dabc', 'cdab', 'abcd', 'bacd', 'acdb', 'bdca', 'dcab', 'dbac', 'adcb', 'adbc']
15
['aabb', 'abba', 'abab', 'baab', 'baba', 'bbaa']
8
['abbbaa', 'ababba', 'baaabb', 'babaab', 'abbaab', 'babbaa', 'bbbaaa', 'bbaaab', 'abaabb', 'bbabaa', 'abbaba', 'ababab', 'aaabbb', 'bbaaba', 'baabab', 'baabba', 'aabbab', 'bababa', 'aababb', 'aabbba']
15

为什么我们想要把 yield 语句和缓存一起使用

比如上面 codewars 代码例子。假设我们传入 f(s) 的字符串长度非常长,那么我们自然希望这个函数每次的迭代值能够缓存下来。

为什么想到使用缓存

假设我们想要使用递归的方法,计算斐波那契数列的第 900 项。那么我们会发现,随着项数的增加,计算斐波那契数列的时间呈现指数级增长。

def fib(n):
	if n<=1:
		return 1
	else:
		return fib(n-1)+fib(n-2)

我们可以试一下为什么。下列代码会帮我们计算 fib() 函数被调用次数。

a=0

def fib(n):
    global a
    a += 1
    if n<=1:    
        return 1
    else:
        return fib(n-1)+fib(n-2)
    
if __name__ == "__main__":
    a = 0
    print(fib(10))
    print(a) 
    a = 0
    print(fib(20))
    print(a)
    a = 0
    print(fib(30))
    print(a)

我们可以看到结果。计算 fib(10) 调用了 177fib() 函数,fib(20)21897 次,fib(30)2692537 次。

但如果我们将 fib() 函数用 @functools.cache 修饰,那么次数会大幅减少。

a=0
import functools

@functools.cache
def fib(n):
    global a
    a += 1
    if n<=1:    
        return 1
    else:
        return fib(n-1)+fib(n-2)
    
if __name__ == "__main__":
    a = 0
    print(fib(10))
    print(a)
    a = 0
    print(fib(20))
    print(a)
    a = 0
    print(fib(30))
    print(a)

结果为

89
11
10946
10
1346269
10
  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值