Python递归优化避免重复递归:递归程序设计艺术(6)

如果不加以优化的话,递归很容易出现重复计算的问题。比如前面计算斐波那契数列,根据公式有F(n) = F(n-1) + F(n-2)。这意味着为了计算F(8),必须计算F(7)和F(6)。而为了计算F(7),必须计算F(6)和F(5),......。这里F(6)就被计算了两次。一般地,递归程序越靠近边界,重复计算的次数就会呈指数增加。当求F(36)时,电脑已经完全僵死,没有反应了。

那怎么解决这个问题呢?规范的方法是这样的:

  1. 要保证递归程序除了参数之外没有读写任何外部数据。这是为了保证程序的优化只在程序内部发生作用,不会影响到外部环境。这个特点称为函数不变性(Immutable),与数据不变性相对应。函数不变性的另一种理解是,只要参数相同,函数总是返回相同的结果。比如随机数发生函数numpy.random.randint()以及日期函数time.time()就不具有不变性。一般地,如果函数是可变的,可以把相关的外部数据引入参数列表中就能解决这个问题。另外,如果函数只是读取外部数据,没有改写它,那函数也是不可变的。
  2. 在函数的参数中增加一个字典型(dict)的参数d,作用是保存递归调用的中间结果。字典是一种能够根据键获取值的数据类型,Python中的字典相当于Java中的Map。当发生重复调用时,从d中直接获取结果即可,不必进行重复的递归调用。
  3. 把函数当前所有参数集中起来构成一个元组t,在递归假设和递归推导之前看看t是否在d中存在,如果存在,说明本次调用是一个重复调用,直接返回d[t]的结果即可。
  4. 如果d[t]不存在,则进行正常的递归假设和递归推导,递归假设(即递归调用)时,不要忘记带上参数d
  5. 最后,在所有可能退出函数的地方(比如 return语句前或者函数体最后一个语句之后)以t为键把函数的返回值存入字典d中。

综上所述,代码 3‑2的优化结果如下:

优化求解斐波那契数列的递归程序

def get_rabbits(months, d=None):
    if months <= 1:        # 递归边界
        return 1
    if d is None:
        d = {}              # 创建一个空字典
    elif months in d:       # 如果是重复递归
        return d[months]
    result = get_rabbits(months - 1, d) + get_rabbits(months - 2, d)
    d[months] = result
    return result

if __name__ == '__main__':
    for months in range(501):
        print(months, ':%6d' % get_rabbits(months), 'pairs')

其中参数d=None表示参数d的缺省值是None,这样在调用这个函数时,可以不为d提供实参,此时形参d的值就是None。None是Python的关键字,表示空指针的意思,相当于C、C++和Java中的null。因为Python在处理任何非基本类型数据时,使用的都是数据的地址,所以,当一个参数或者变量不指向任何对象时,可以赋予它一个None值。

可能有读者会觉得奇怪,既然如此,那么为什么不把d的缺省值定为空字典(即{})?这样还能避免像上述代码的第一行那样对等于None的d进行处理。这是因为字典是非基本类型,你看到的是空字典{},实际是指向空字典的指针,这意味着所有调用这个函数的地方,只要参数d缺省,对应的实参实际指向的是同一个空字典。这样两次不同调用之间就会产生干扰。

注意,虽然d是可以缺省的,但在递归假设时不要省略它,见代码的倒数第6行。初学者很容易在这里犯错。

判断一个键是否在字典中存在用in操作,见代码第6行if month in d。这个操作不仅对字典,对所有其他序列类型如列表(list)、元组(tuple)、集合(set)也管用。

经过上述优化,程序不但能计算50个月的兔子数,500个月的也能飞快算出。如果不优化,假设F(1)和F(0)分别只需1毫秒进行计算,则第50个月的兔子数需要146天才能计算出来,100个月的需要112亿年才能算出。

下面代码是优化人字形铁路问题的结果:

优化的人字形铁路问题递归程序

def get_trains(n, m=0, d=None):
    if n == 0:
        return 1
    if d is None:
        d = {}
    t = (n, m)
    if t in d:
        return d[t]
    result = get_trains(n-1, m+1, d)
    if m > 0:
        result += get_trains(n, m-1, d)
    d[t] = result
    return result

if __name__ == '__main__':
    for n in range(1, 100+1):
        print(n, get_trains(n))

读者可以分别用老程序和新程序计算100列火车的排列数,可以体验到截然不同的效果。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

方林博士

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值