Fenrier Lab

利用 Cont Monad 封装计算

从回调方法说起

回调是一种很有效的异步编程方法,举一个简单的例子,如果我们要执行一个数据库查询语句,通常会像下面这样


Statement stat = ...;
String sql = ...;

ResultSet result = stat.executeQuery(sql);

//后续工作

其中,后续工作可能与查询结果有关,也可能无关。如果无关的话,那么等待查询结果返回其实是不必要的,如果对性能要求比较高,那这很显然会成为瓶颈。一个比较粗糙的解决方法是把耗时操作放到另一个线程中执行


new Thread(() -> {
    ResultSet result = stat.executeQuery(sql);
}).start();

//后续代码

但像上面这样显式的创建线程对客户端很不友好,一方面是线程创建的开销,另一方面也不利于代码的维护。一种比较好的方式是由库提供者负责线程的管理,向用户屏蔽具体的细节,也就是说,把多线程代码放到具体的执行方法内部,用户不可见。这样一来,执行方法便几乎是立即返回的(因为具体的执行过程放到了另一个线程中,当前线程不阻塞),但也就意味着客户端无法感知执行结果。这时候回调方法便派上用场了,回调方法相当于用户派出去追踪结果的制导武器,一旦执行结果出现,无论在哪里,回调方法便上前去消费掉。就像下面这样


Function<ResultSet, Void> callback = res -> ...

client.query(sql, callback);

这里的 client 便是与数据库交互的客户端工具,可以用 jdbc 来实现


public void query(String sql, Function<ResultSet, Void> callback) {
    new Thread(() -> {

        Statement stat = connection.createStatement();
        ResultSet res = stat.executeQuery(sql);
        callback.apply(res);

    }).start();
}

连续传递风格 (Continuation Passing Style)

上面我们提到的传递回调函数的方法,在函数式编程领域中有一个与之对应但更深刻的概念叫 CPS,其中 continuation 就可以被看作是回调函数,它接受当前计算的结果作为参数,并在未来某个时刻运行。为了加深印象,我们做几组对比(为了配合函数式的风格,这里我们用 lambda 表达式来表示函数,也为习惯后续的 haskell 代码做准备)

val plus10: Int => Int = a => a + 10
val result = plus10(10)
println(result)

这是普通函数,它的作用是对输入值加 10 后返回,然后打印出结果,很简单。接下来我们把它改造成 CPS 函数

val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val f: Int => Unit = res => println(res)
plus10CPS(10)(f)

在上面的代码中 f 就是回调函数(只不过写成了 lambda 表达式形式),也就是 continuation。plus10CPS 在 apply 10 之后得到一个类型为 (Int => Unit) => Unit 的值,这是一个相当重要的概念,我们后面还会提到,在这里可以看到它接收 continuation 作为输入,并在内部把当前计算的结果(即 a + 10)传给 continuation。下面我们来看第二个例子

val mul2: Int => Int = a => a * 2
val result = mul2(10)
println(result)

val mul2CPS: Int => (Int => Unit) => Unit = a => f => f(a * 2)
val f: Int => Unit = res => println(res)
mul2CPS(10)(f)

这个例子和前面没什么不同,只不过是把加 10 改成了乘 2,下面就是重点了,我们要把 plus10 和 mul2 这两个函数组合成一个更大的函数

val plus10ThenMul2: Int => Int = a => mul2(plus10(a))

普通函数的组合一目了然,但是 CPS 函数的组合则有点复杂,我们一步一步来,首先我们可以想到组合之后的函数的类型仍然是

val plus10ThenMul2CPS: Int => (Int => Unit) => Unit = a => f => {
    ...
}

它第一个输入(即 a)是传给 plus10CPS 的

a => f => {
    val c = plus10CPS(a)
    ...
}

这里的 c 需要接收一个 continuation ,但不是 f,因为 f 是 plus10ThenMul2CPS 的 continuation,c 需要接收的 continuation 肯定是 plus10CPS 之后的操作,不就是 mul2CPS 吗?但也不能把 mul2CPS 传给 c,因为类型不对,我们需要构造一个中间量,假设它是 f2

a => f => {
   val c = plus10CPS(a)
   val f2: Int => Unit = b => ...
   c(f2)
}

前面我们提到,c 会把当前计算的结果(即 (10 + a))作为参数传递给 continuation,于是 f2 的输入就是 (10 + a),也就是说把 (10 + a) 的值绑定到了 b 上。下一步自然是把 b 传给 mul2CPS

a => f => {
   val c = plus10CPS(a)
   val f2: Int => Unit = b => {
     c2 = mul2CPS(b)
     ...
   }
   c(f2)
}

mul2CPS在接收了 b 之后返回的 c2 仍然需要接收一个 continuation,并把当前计算结果(此时是 b * 2 )传给它,因为函数组合操作已经完成了,所以这时的 continuation 就是 f。

a => f => {
   val c = plus10CPS(a)
   val f2: Int => Unit = b => {
     c2 = mul2CPS(b)
     c2(f)
   }
   c(f2)
}

简化一下

val plus10CPSThenMul2CPS: Int => (Int => Unit) => Unit = a => f => plus10CPS(a)(b => mul2CPS(b)(f))

回调嵌套问题

举一个简单例子

double a = 1.0;
double b = a + 2.0;
double c = b + 3.0;
double d = c + 4.0;
System.out.println(d);

这是原型,它是一个顺序操作,并且除了第一行之外的每一行都依赖于前一行的结果,如果把它转换成回调的版本

double a = 1.0;
add2(a, res1 -> {
    double b = a + 2;
    add3(b, res2 -> {
        double c = b + 3;
        add4(c, res3 -> {
            double d = c + 4;
            System.out.println(d);
        })
    })
});

这样的代码看起来就不那么舒服了,为了化解这种

使用 Cont 封装计算

在介绍 CPS 的时候,我们提到了一个重要的类型 (Int => Unit) => Unit ,或者更一般地 (A => R) => R,这是 一个 CPS 函数接收第一个参数后的返回类型,在 haskell 中,这个类型的名字就叫 Cont,也就是下面的 c2

val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val x: Int = ...
val c2: (Int => Unit) => Unit = plus10CPS(x)
========== //另一种表达方式
val c2: (Int => Unit) => Unit = f => f(x + 10)

从上面 c2 的第二种表达方式可以看到,Cont 是把当前计算结果 (x + 10) 传递给 continuation (即 f) 的抽象模型。为什么说它是抽象模型呢,因为 plus10CPS(x) 只完成了当前计算 (x + 10),它的 continuation 在目前还不确定,而 Cont 模型表达了未来将要接收并应用 continuation 这一个行为。

在 haskell 中 Cont 的声明如下

newtype Cont r a = Cont {runCont:: (a -> r) -> r}

这一句为 (a -> r) -> r 这个类型赋予了一个名字叫 Cont r a,现在假如有一个值 c,给出两种等价的表达

c1:: Cont r a
c2:: (a -> r) -> r

虽然两者定义上等价,但实际上,c2 可以接收参数,并且这个参数就是我们前面提到的 continuation,而 c1 则不行。为了让 c1 也能计算,需要使用 runCont 把真正的类型 (a -> r) -> r 暴露出来,runCont 的类型正是

runCont:: Cont r a -> (a -> r) -> r

所以,Cont 更像是对 (a -> r) -> r 类型的一种封装,基于这样的认识,我们可以在 scala 中这样定义

class Cont[R, A](c2: (A => R) => R) {
  
  def runCont(f: A => R) = c2(f) 

}

此时,runCont 的参数是 continuation。假如有一个 Cont 类型的实例 c1,那么

c1.runCont

才是和 c2 等价的,因为它们都接收一个 continuation 作为参数。

既然类型 Cont[R, A] 是 (A => R) => R 的封装,那么前面我们提到的 CPS 函数能否转换成返回 Cont 的形式呢

val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val plus10Cont: Int => Cont[Unit, Int] = ?

答案是显然的,只需要用 plus10CPS 传入第一个参数的返回值构造 Cont 对象就行了

val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)

val x: Int = ...
val c2 = plus10CPS(x)
// val c2 = f => f(x + 10) // 另一种表示方法
val plus10Cont: Int => Cont[Unit, Int] = a => new Cont(c2)
  plus10CPS plus10Cont
输入类型 Int Int
返回类型 (Int => Unit) => Unit Cont[Unit ,Int]

上表是 plus10CPS 和 plus10Cont 的简单对比。mul2Cont 也可以如法炮制

val mul2Cont: Int => Cont[Unit, Int] = a => new Cont(f => f(a * 2))

反过来, Cont 形式的函数也可以得到 CPS 函数,我们仍然一步一步推导,首先可以写出 plus10CPSFromCont 的大致形式

plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
  ...
}

这里的 a 是 CPS 函数的第一个参数,并且恰好也是 plus10Cont 的第一个参数,所以

plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
  val c1: Cont[Unit, Int] = plus10Cont(a)
  ...
}

而 f 是 continuation,又刚好是 c1.runCont 的参数,于是

plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
  val c1: Cont[Unit, Int] = plus10Cont(a)
  c1.runCont(f)
}

简化一下

plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => plus10CPS(a).runCont(f)

接下来又来到了喜闻乐见的函数组合环节,我们需要把两个 Cont 函数组合成一个更大的函数

val plus10ThenMul2Cont: Int => Cont[Unit, Int] = ...

有了 CPS 函数组合的经验,我们只需要在此基础上略微改动,首先,由于返回值类型是 Cont,所以最终结果应该被用来新建 Cont 对象

val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a => 
    new Cont(
        f => plus10CPS(a)(b => mul2CPS(b)(f))
    )

然后,plus10CPS 应该从 plus10Cont 得到,mul2CPS 应该从 mul2Cont 得到,这我们刚好学过,于是组合后的函数就为

val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a => 
    new Cont(
        f => plus10Cont(a).runCont(b => mul2Cont(b).runCont(f))
    )

到这里,我们费劲心思,终于把两个返回 Cont 的函数组合成了一个更大的函数,其中意义又在哪里呢?当然不是为了花式的表达

\x -> 2 * (10 + x)

而是为了实现下面要介绍的 Cont Monad 的 bind (也就是 »=) 函数

实现 Cont Monad

类似于 Maybe, [] 这些类型,Cont 也是一个 Monad,在 haskell 中,它的声明如下

instance Monad (Cont r) where 
    return :: a -> Cont r a
    return x = ...

    >>= :: m a -> (a -> m b) -> m b
    m >>= k = ...

其中 return 很好实现,它相当于当前计算为自身,也就是

return x = Cont \f -> f x
-- 对比 plus10
plus10Cont x = Cont \f -> f $ x + 10

由于 return 是 scala 语言的关键字,所以在 scala 中我们用 `return` 代替

def `return`[R, A](x: A): Cont[R, A] = new Cont(f => f(x))

// 当然也可以用 lambda 表达式定义
val `return`: A => Cont[R, A] = x => new Cont(f => f(x))

在 scala 中, bind (»=) 函数的类型为

def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
    ...
}

这里的参数 k 的类型和前面我们举的例子 plus10Cont,mul2Cont 是一样的,我们令

val x = 10
val c1 = plus10Cont(x)
val c3 = c1.bind(mul2Cont)

如果让 plus10Cont 自述它的行为

当前我所要进行的计算是 b = x + 10,但对我来说 x 仍是未知数,需要调用者提供,并且完成当前计算之后,我还需要一个 f 函数,把 b 传递给它,最终才能完成计算。

那么 c1 这个 Cont 表达的语义则是

当前我所要做的计算是 f(b),b 对我来说是已知的,但是 f 还未知晓,需要调用者通过 runCont 接口提供。

如果我们调用 c1 的 runCont 方法,并且

val f = x => println(x)

就相当于告诉它

hi! c1,我这里有个 continuation,名叫 f,现在我将它传给你,需要你运行它。

于是 c1 完成了计算。但当我们调用 c1 的 bind 方法时,事情就变得稍微复杂一点了,同 plus10Cont 一样,mul2Cont 的自述是

当前我所要进行的计算是 d = y * 2,但对我来说 y 仍是未知数,需要调用者提供,并且完成当前计算之后,我还需要一个 g 函数,把 d 传递给它,最终才能完成计算。

现在的问题是 y 的值是多少?处在 c1 的计算环境中,唯一的已知量是 b,所以毫无疑问 y = b,但是这里的 b 得处于 runCont 的环境下才能获得,所以在 bind 函数内部,我们需要运行它自身的 runCont

val some = g => runCont(b => {
    val cm = mul2Cont(b)
    cm.runCont(g)
})

这里我们显然不是 »= 的返回类型,另一方面,g 也还不确定。为了得到 Cont 类型的返回值,和之前一样,我们需要显式地构造它,也就是

new Cont(h => ...)

这里的 h 是完成当前计算后需要传给 runCont 的参数,而 g 正是传给 cm 的 runCont 的,所以何不让 h = g,从而完成 bind 函数的实现

def bind(k: A => Cont[R, B]): Cont[R, B] = {
    new Cont(g => runCont(b => {
        val cm = k(b)
        cm.runCont(g)
    }))
}

简化一下

def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
    new Cont(g => runCont(b => k(b).runCont(g)))
}

如果我们把前面讨论的两个 Cont 函数的组合方法拿过来和 »= 进行对比

val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a => 
    new Cont(
        f => plus10Cont(a).runCont(b => mul2Cont(b).runCont(f))
    )

上面的 plus10Cont(a) 是一个 Cont,在 Cont 内部的可以直接调用 runCont,如果把 plus10Cont(a) 省略掉再看

new Cont(
        f => runCont(b => mul2Cont(b).runCont(f))
    )

可以发现,这不就是 »= 方法的函数体吗?所以看起来晦涩的 »= 方法实际上表达的是两个连续操作的组合,只不过是在其中一个 Cont 内进行的,返回的也是 Cont。这样理解的话,haskell 中 Cont 的 »= 也就一目了然的

m >>= k = Cont $ \f -> runCont m (\a -> runCont k f)

下面我们给出 scala 版本的 Cont 实现

class Cont[R, A](cont: (A => R) => R){

  def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
    Cont(f => runCont(b => k(b).runCont(f)))
  }

  def runCont(f: A => R): R = cont(f)
}
    
object Cont {

  def `return`[R, A](a: A): Cont[R, A] = {
    Cont((f: A => R)=> f(a))
  }

  def apply[R, A](cont: (A => R) => R): Cont[R, A] = new Cont(cont)

}

简单使用一下

Cont.`return`[Unit, Int](10)
      .>>=(res => Cont.`return`(res * 2))
      .>>=(res => Cont.`return`(res / 3.0))
      .runCont(a => println(a))

很棒!这一节,我们介绍了用 Cont 封装计算的方法,在形式上,实现了把连续操作组合成一个更大操作的流式调用风格。但是如果中间操作抛出异常的话,整个过程就中断了,为了捕捉异常,我们需要在语句外加上 try catch 语法,这就属于干脏活了。为了更优雅的处理中间过程发生的异常,下面我们开始介绍 callcc。

使用 callcc 进行异常处理

总结

本文遵守 CC-BY-NC-4.0 许可协议。

Creative Commons License

欢迎转载,转载需注明出处,且禁止用于商业目的。