Continuation Passing Style和Trampoline
前言
以下是我读Essentials Of Programming Languages时候的理解,可能有不对的地方,请谨慎参考。
代码均为Scala 3。
Continuation
这个的翻译大概是后续部分、延续。
例如对于1 + (2 + 3)
,如果我们想要计算1 + _
,那我们就要计算(2 + 3)
。那么(2 + 3)
的延续就是1 + _
。
所以我们就将这个1 + _
作为环境传下去。
也就是说,后续的函数会拿到两个参数,一个是(2 + 3)
这个表达式,另一个则是1 + _
这个延续,告诉它计算了表达式之后该做的事。
这种形式就被称为Continuation Passing Style(以下简称CPS)。
我们以加法为例,用Scala来实现一下。
一个例子
我们以加法为例。我们定义一个简单的加法的抽象语法树:
enum Exp {
case IntExp(v: Int)
case SumExp(e1: Exp, e2: Exp)
}
在一般情况下,我们会简单的通过不断的递归来计算表达式的值:
def simpleParse(exp: Exp): Int = exp match {
case Exp.IntExp(v) => v
case Exp.SumExp(e1, e2) => simpleParse(e1) + simpleParse(e2)
}
我们会发现,它并不是一个尾递归的形式,因此有栈溢出的风险。
@main
def main() = {
import Exp._
val exps = (1 until 10000).map(IntExp(_)).foldRight(IntExp(0))(SumExp(_, _))
println(simpleParse(exps))
}
在我的电脑上,大概10000次递归就会栈溢出。
延续传递模式
我们将它转为CPS。
我们首先定义会有哪些延续:
enum Continuation {
case EndK // 全部算完了
case Sum1K(e2: Exp, cont: Continuation) // 正在计算加号左侧表达式
case Sum2K(v1: Int, cont: Continuation) // 正在计算加号右侧表达式
}
之后我们定义对表达式的计算:
def applyContext(expVal: Int)(cont: Continuation): Int = cont match {
case Continuation.EndK => expVal
case Continuation.Sum1K(e2, cont) => cpsParse(e2)(Continuation.Sum2K(expVal, cont))
case Continuation.Sum2K(v1, cont) => applyContext(v1 + expVal)(cont)
}
def cpsParse(exp: Exp)(cont: Continuation): Int = exp match {
case Exp.IntExp(v) => applyContext(v)(cont)
case Exp.SumExp(e1, e2) => cpsParse(e1)(Continuation.Sum1K(e2, cont))
}
对于A + B
以及延续K
而言,首先我们计算A
,将_ + B
这个延续传进去。计算完A
得到a
之后,我们计算B
,并将a + _
这个延续传下去。
计算完B
得到b
之后,我们计算a + b
,并应用到K
之中。
我们可以看到在应用了CPS之后,就呈现一个尾调用的状态。但此时如果我们执行:
@main
def main = {
import Exp._
val exps = (1 until 10000).map(IntExp(_)).foldRight(IntExp(0))(SumExp(_, _))
println(cpsParse(exps)(Continuation.EndK))
}
依然会栈溢出。
Trampoline
这时候我们就需要将尾调用转换为尾递归,用到一个蹦床。
蹦床是一个尾递归函数:
enum Bounce {
case ExpVal(v: Int)
case B(f: () => Bounce)
}
def trampoline(bounce: Bounce): Int = bounce match {
case Bounce.ExpVal(v) => v
case Bounce.B(b) => trampoline(b())
}
可以看到,如果蹦床函数的参数是一个函数,它就会对它进行求值,并把结果放进另一个蹦床函数中进行求值。 而因为它是尾递归的,因此并不会有栈溢出现象。 体现在堆栈上就是推进去一个trampoline、再弹出来、再推进去一个trampoline、再弹出来,上上下下像是一个蹦床,也就是名字的由来。
我们所要做的就是把之前直接互相尾调用的函数改为间接尾调用,把想要的结果用函数包裹,而不是直接求值。 这样每一步都会由蹦床函数来执行。
修改之后便是这样:
def applyContext(expVal: Int)(cont: Continuation): Bounce = cont match {
case Continuation.EndK => Bounce.ExpVal(expVal)
case Continuation.Sum1K(e2, cont) => Bounce.B(() => cpsParse(e2)(Continuation.Sum2K(expVal, cont)))
case Continuation.Sum2K(v1, cont) => Bounce.B(() => applyContext(v1 + expVal)(cont))
}
def cpsParse(exp: Exp)(cont: Continuation): Bounce = exp match {
case Exp.IntExp(v) => Bounce.B(() => applyContext(v)(cont))
case Exp.SumExp(e1, e2) => Bounce.B(() => cpsParse(e1)(Continuation.Sum1K(e2, cont)))
}
然后运算:
@main
def main = {
import Exp._
val exps = (1 until 100000).map(IntExp(_)).foldRight(IntExp(0))(SumExp(_, _))
println(trampoline(cpsParse(exps)(Continuation.EndK)))
}
这里数字随便填,栈溢出算我输.jpg。
当然有人可能会问,那么Java没有尾递归优化,怎么办呢?当然是用迭代循环啦。
总结
这里大致介绍了下我个人对于continuation和trampoline的理解。
Essential Of Programming Languages是一本不错的书,向大家推荐。