实现尾递归优化
// 定义一个尾递归优化后的阶乘函数
def factorial(n: Int): Int = {
// 定义一个辅助函数,接受两个参数:当前值和累积结果
def loop(x: Int, acc: Int): Int = {
// 如果当前值等于1,返回累积结果
if (x == 1) acc
// 否则调用自身,更新当前值和累积结果
else loop(x - 1, x * acc)
}
// 调用辅助函数,传入初始值和1
loop(n, 1)
}
// 调用阶乘函数
println(factorial(5)) // 输出120
为什么要进行尾递归优化?
为什么要进行尾递归优化,是因为尾递归可以减少调用栈的占用,从而避免栈溢出的风险,提高性能和内存利用率。结合代码来详解一下:
-
没有优化的递归函数
// 定义一个阶乘函数 def factorial(n: Int): Int = { // 如果n等于1,返回1 if (n == 1) 1 // 否则返回n乘以n-1的阶乘 else n * factorial(n - 1) } // 调用阶乘函数 println(factorial(5)) // 输出120
这个函数在计算阶乘的过程中,
会产生多个调用栈
,每次调用自身都会保存当前的参数和返回位置,等待下一次调用返回结果。例如,当我们计算factorial(5)时,会产生如下的调用栈:factorial(5) -> n * factorial(4) factorial(4) -> n * factorial(3) factorial(3) -> n * factorial(2) factorial(2) -> n * factorial(1) factorial(1) -> 1
当factorial(1)返回1时,才开始从栈顶到栈底依次计算结果,最后返回120。这样做的缺点是,
如果n很大,会产生很多的调用栈,占用很多内存空间,甚至可能导致栈溢出
。
-
优化后的尾递归函数
// 定义一个尾递归优化后的阶乘函数 def factorial(n: Int): Int = { // 定义一个辅助函数,接受两个参数:当前值和累积结果 def loop(x: Int, acc: Int): Int = { // 如果当前值等于1,返回累积结果 if (x == 1) acc // 否则调用自身,更新当前值和累积结果 else loop(x - 1, x * acc) } // 调用辅助函数,传入初始值和1 loop(n, 1) } // 调用阶乘函数 println(factorial(5)) // 输出120
这个函数在计算阶乘的过程中,
只会产生一个调用栈
,每次调用自身都不会保存当前的参数和返回位置,而是
直接替换成下一次调用的参数和返回位置`。例如,当我们计算factorial(5)时,只会产生如下的调用栈:loop(5, 1) -> loop(4, 5) -> loop(3, 20) -> loop(2, 60) -> loop(1, 120) -> 120
当loop(1, 120)返回120时,就是最终的结果,不需要再从栈顶到栈底依次计算结果。这样做的优点是,无论n多大,都只会产生一个调用栈,节省了内存空间,也避免了栈溢出。