利用 Cont Monad 封装计算
从回调方法说起
回调是一种很有效的异步编程方法,举一个简单的例子,如果我们要执行一个数据库查询语句,通常会像下面这样
Statement stat = ...;
String sql = ...;
ResultSet result = stat.executeQuery(sql);
//后续工作
其中,后续工作可能与查询结果有关,也可能无关。如果无关的话,那么等待查询结果返回其实是不必要的,如果对性能要求比较高,那这很显然会成为瓶颈。一个比较粗糙的解决方法是把耗时操作放到另一个线程中执行
new Thread(() -> {
ResultSet result = stat.executeQuery(sql);
}).start();
//后续代码
但像上面这样显式的创建线程对客户端很不友好,一方面是线程创建的开销,另一方面也不利于代码的维护。一种比较好的方式是由库提供者负责线程的管理,向用户屏蔽具体的细节,也就是说,把多线程代码放到具体的执行方法内部,用户不可见。这样一来,执行方法便几乎是立即返回的(因为具体的执行过程放到了另一个线程中,当前线程不阻塞),但也就意味着客户端无法感知执行结果。这时候回调方法便派上用场了,回调方法相当于用户派出去追踪结果的制导武器,一旦执行结果出现,无论在哪里,回调方法便上前去消费掉。就像下面这样
Function<ResultSet, Void> callback = res -> ...
client.query(sql, callback);
这里的 client 便是与数据库交互的客户端工具,可以用 jdbc 来实现
public void query(String sql, Function<ResultSet, Void> callback) {
new Thread(() -> {
Statement stat = connection.createStatement();
ResultSet res = stat.executeQuery(sql);
callback.apply(res);
}).start();
}
连续传递风格 (Continuation Passing Style)
上面我们提到的传递回调函数的方法,在函数式编程领域中有一个与之对应但更深刻的概念叫 CPS,其中 continuation 就可以被看作是回调函数,它接受当前计算的结果作为参数,并在未来某个时刻运行。为了加深印象,我们做几组对比(为了配合函数式的风格,这里我们用 lambda 表达式来表示函数,也为习惯后续的 haskell 代码做准备)
val plus10: Int => Int = a => a + 10
val result = plus10(10)
println(result)
这是普通函数,它的作用是对输入值加 10 后返回,然后打印出结果,很简单。接下来我们把它改造成 CPS 函数
val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val f: Int => Unit = res => println(res)
plus10CPS(10)(f)
在上面的代码中 f 就是回调函数(只不过写成了 lambda 表达式形式),也就是 continuation。plus10CPS 在 apply 10 之后得到一个类型为 (Int => Unit) => Unit 的值,这是一个相当重要的概念,我们后面还会提到,在这里可以看到它接收 continuation 作为输入,并在内部把当前计算的结果(即 a + 10)传给 continuation。下面我们来看第二个例子
val mul2: Int => Int = a => a * 2
val result = mul2(10)
println(result)
val mul2CPS: Int => (Int => Unit) => Unit = a => f => f(a * 2)
val f: Int => Unit = res => println(res)
mul2CPS(10)(f)
这个例子和前面没什么不同,只不过是把加 10 改成了乘 2,下面就是重点了,我们要把 plus10 和 mul2 这两个函数组合成一个更大的函数
val plus10ThenMul2: Int => Int = a => mul2(plus10(a))
普通函数的组合一目了然,但是 CPS 函数的组合则有点复杂,我们一步一步来,首先我们可以想到组合之后的函数的类型仍然是
val plus10ThenMul2CPS: Int => (Int => Unit) => Unit = a => f => {
...
}
它第一个输入(即 a)是传给 plus10CPS 的
a => f => {
val c = plus10CPS(a)
...
}
这里的 c 需要接收一个 continuation ,但不是 f,因为 f 是 plus10ThenMul2CPS 的 continuation,c 需要接收的 continuation 肯定是 plus10CPS 之后的操作,不就是 mul2CPS 吗?但也不能把 mul2CPS 传给 c,因为类型不对,我们需要构造一个中间量,假设它是 f2
a => f => {
val c = plus10CPS(a)
val f2: Int => Unit = b => ...
c(f2)
}
前面我们提到,c 会把当前计算的结果(即 (10 + a))作为参数传递给 continuation,于是 f2 的输入就是 (10 + a),也就是说把 (10 + a) 的值绑定到了 b 上。下一步自然是把 b 传给 mul2CPS
a => f => {
val c = plus10CPS(a)
val f2: Int => Unit = b => {
c2 = mul2CPS(b)
...
}
c(f2)
}
mul2CPS在接收了 b 之后返回的 c2 仍然需要接收一个 continuation,并把当前计算结果(此时是 b * 2 )传给它,因为函数组合操作已经完成了,所以这时的 continuation 就是 f。
a => f => {
val c = plus10CPS(a)
val f2: Int => Unit = b => {
c2 = mul2CPS(b)
c2(f)
}
c(f2)
}
简化一下
val plus10CPSThenMul2CPS: Int => (Int => Unit) => Unit = a => f => plus10CPS(a)(b => mul2CPS(b)(f))
回调嵌套问题
举一个简单例子
double a = 1.0;
double b = a + 2.0;
double c = b + 3.0;
double d = c + 4.0;
System.out.println(d);
这是原型,它是一个顺序操作,并且除了第一行之外的每一行都依赖于前一行的结果,如果把它转换成回调的版本
double a = 1.0;
add2(a, res1 -> {
double b = a + 2;
add3(b, res2 -> {
double c = b + 3;
add4(c, res3 -> {
double d = c + 4;
System.out.println(d);
})
})
});
这样的代码看起来就不那么舒服了,为了化解这种
使用 Cont 封装计算
在介绍 CPS 的时候,我们提到了一个重要的类型 (Int => Unit) => Unit ,或者更一般地 (A => R) => R,这是 一个 CPS 函数接收第一个参数后的返回类型,在 haskell 中,这个类型的名字就叫 Cont,也就是下面的 c2
val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val x: Int = ...
val c2: (Int => Unit) => Unit = plus10CPS(x)
========== //另一种表达方式
val c2: (Int => Unit) => Unit = f => f(x + 10)
从上面 c2 的第二种表达方式可以看到,Cont 是把当前计算结果 (x + 10) 传递给 continuation (即 f) 的抽象模型。为什么说它是抽象模型呢,因为 plus10CPS(x) 只完成了当前计算 (x + 10),它的 continuation 在目前还不确定,而 Cont 模型表达了未来将要接收并应用 continuation 这一个行为。
在 haskell 中 Cont 的声明如下
newtype Cont r a = Cont {runCont:: (a -> r) -> r}
这一句为 (a -> r) -> r 这个类型赋予了一个名字叫 Cont r a,现在假如有一个值 c,给出两种等价的表达
c1:: Cont r a
c2:: (a -> r) -> r
虽然两者定义上等价,但实际上,c2 可以接收参数,并且这个参数就是我们前面提到的 continuation,而 c1 则不行。为了让 c1 也能计算,需要使用 runCont 把真正的类型 (a -> r) -> r 暴露出来,runCont 的类型正是
runCont:: Cont r a -> (a -> r) -> r
所以,Cont 更像是对 (a -> r) -> r 类型的一种封装,基于这样的认识,我们可以在 scala 中这样定义
class Cont[R, A](c2: (A => R) => R) {
def runCont(f: A => R) = c2(f)
}
此时,runCont 的参数是 continuation。假如有一个 Cont 类型的实例 c1,那么
c1.runCont
才是和 c2 等价的,因为它们都接收一个 continuation 作为参数。
既然类型 Cont[R, A] 是 (A => R) => R 的封装,那么前面我们提到的 CPS 函数能否转换成返回 Cont 的形式呢
val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val plus10Cont: Int => Cont[Unit, Int] = ?
答案是显然的,只需要用 plus10CPS 传入第一个参数的返回值构造 Cont 对象就行了
val plus10CPS: Int => (Int => Unit) => Unit = a => f => f(a + 10)
val x: Int = ...
val c2 = plus10CPS(x)
// val c2 = f => f(x + 10) // 另一种表示方法
val plus10Cont: Int => Cont[Unit, Int] = a => new Cont(c2)
plus10CPS | plus10Cont | |
---|---|---|
输入类型 | Int | Int |
返回类型 | (Int => Unit) => Unit | Cont[Unit ,Int] |
上表是 plus10CPS 和 plus10Cont 的简单对比。mul2Cont 也可以如法炮制
val mul2Cont: Int => Cont[Unit, Int] = a => new Cont(f => f(a * 2))
反过来, Cont 形式的函数也可以得到 CPS 函数,我们仍然一步一步推导,首先可以写出 plus10CPSFromCont 的大致形式
plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
...
}
这里的 a 是 CPS 函数的第一个参数,并且恰好也是 plus10Cont 的第一个参数,所以
plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
val c1: Cont[Unit, Int] = plus10Cont(a)
...
}
而 f 是 continuation,又刚好是 c1.runCont 的参数,于是
plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => {
val c1: Cont[Unit, Int] = plus10Cont(a)
c1.runCont(f)
}
简化一下
plus10CPSFromCont: Int => (Int => Unit) => Unit = a => f => plus10CPS(a).runCont(f)
接下来又来到了喜闻乐见的函数组合环节,我们需要把两个 Cont 函数组合成一个更大的函数
val plus10ThenMul2Cont: Int => Cont[Unit, Int] = ...
有了 CPS 函数组合的经验,我们只需要在此基础上略微改动,首先,由于返回值类型是 Cont,所以最终结果应该被用来新建 Cont 对象
val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a =>
new Cont(
f => plus10CPS(a)(b => mul2CPS(b)(f))
)
然后,plus10CPS 应该从 plus10Cont 得到,mul2CPS 应该从 mul2Cont 得到,这我们刚好学过,于是组合后的函数就为
val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a =>
new Cont(
f => plus10Cont(a).runCont(b => mul2Cont(b).runCont(f))
)
到这里,我们费劲心思,终于把两个返回 Cont 的函数组合成了一个更大的函数,其中意义又在哪里呢?当然不是为了花式的表达
\x -> 2 * (10 + x)
而是为了实现下面要介绍的 Cont Monad 的 bind (也就是 »=) 函数
实现 Cont Monad
类似于 Maybe, [] 这些类型,Cont 也是一个 Monad,在 haskell 中,它的声明如下
instance Monad (Cont r) where
return :: a -> Cont r a
return x = ...
>>= :: m a -> (a -> m b) -> m b
m >>= k = ...
其中 return 很好实现,它相当于当前计算为自身,也就是
return x = Cont \f -> f x
-- 对比 plus10
plus10Cont x = Cont \f -> f $ x + 10
由于 return 是 scala 语言的关键字,所以在 scala 中我们用 `return` 代替
def `return`[R, A](x: A): Cont[R, A] = new Cont(f => f(x))
// 当然也可以用 lambda 表达式定义
val `return`: A => Cont[R, A] = x => new Cont(f => f(x))
在 scala 中, bind (»=) 函数的类型为
def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
...
}
这里的参数 k 的类型和前面我们举的例子 plus10Cont,mul2Cont 是一样的,我们令
val x = 10
val c1 = plus10Cont(x)
val c3 = c1.bind(mul2Cont)
如果让 plus10Cont 自述它的行为
当前我所要进行的计算是 b = x + 10,但对我来说 x 仍是未知数,需要调用者提供,并且完成当前计算之后,我还需要一个 f 函数,把 b 传递给它,最终才能完成计算。
那么 c1 这个 Cont 表达的语义则是
当前我所要做的计算是 f(b),b 对我来说是已知的,但是 f 还未知晓,需要调用者通过 runCont 接口提供。
如果我们调用 c1 的 runCont 方法,并且
val f = x => println(x)
就相当于告诉它
hi! c1,我这里有个 continuation,名叫 f,现在我将它传给你,需要你运行它。
于是 c1 完成了计算。但当我们调用 c1 的 bind 方法时,事情就变得稍微复杂一点了,同 plus10Cont 一样,mul2Cont 的自述是
当前我所要进行的计算是 d = y * 2,但对我来说 y 仍是未知数,需要调用者提供,并且完成当前计算之后,我还需要一个 g 函数,把 d 传递给它,最终才能完成计算。
现在的问题是 y 的值是多少?处在 c1 的计算环境中,唯一的已知量是 b,所以毫无疑问 y = b,但是这里的 b 得处于 runCont 的环境下才能获得,所以在 bind 函数内部,我们需要运行它自身的 runCont
val some = g => runCont(b => {
val cm = mul2Cont(b)
cm.runCont(g)
})
这里我们显然不是 »= 的返回类型,另一方面,g 也还不确定。为了得到 Cont 类型的返回值,和之前一样,我们需要显式地构造它,也就是
new Cont(h => ...)
这里的 h 是完成当前计算后需要传给 runCont 的参数,而 g 正是传给 cm 的 runCont 的,所以何不让 h = g,从而完成 bind 函数的实现
def bind(k: A => Cont[R, B]): Cont[R, B] = {
new Cont(g => runCont(b => {
val cm = k(b)
cm.runCont(g)
}))
}
简化一下
def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
new Cont(g => runCont(b => k(b).runCont(g)))
}
如果我们把前面讨论的两个 Cont 函数的组合方法拿过来和 »= 进行对比
val plus10ThenMul2Cont: Int => Cont[Unit, Int] = a =>
new Cont(
f => plus10Cont(a).runCont(b => mul2Cont(b).runCont(f))
)
上面的 plus10Cont(a) 是一个 Cont,在 Cont 内部的可以直接调用 runCont,如果把 plus10Cont(a) 省略掉再看
new Cont(
f => runCont(b => mul2Cont(b).runCont(f))
)
可以发现,这不就是 »= 方法的函数体吗?所以看起来晦涩的 »= 方法实际上表达的是两个连续操作的组合,只不过是在其中一个 Cont 内进行的,返回的也是 Cont。这样理解的话,haskell 中 Cont 的 »= 也就一目了然的
m >>= k = Cont $ \f -> runCont m (\a -> runCont k f)
下面我们给出 scala 版本的 Cont 实现
class Cont[R, A](cont: (A => R) => R){
def >>=[B](k: A => Cont[R, B]): Cont[R, B] = {
Cont(f => runCont(b => k(b).runCont(f)))
}
def runCont(f: A => R): R = cont(f)
}
object Cont {
def `return`[R, A](a: A): Cont[R, A] = {
Cont((f: A => R)=> f(a))
}
def apply[R, A](cont: (A => R) => R): Cont[R, A] = new Cont(cont)
}
简单使用一下
Cont.`return`[Unit, Int](10)
.>>=(res => Cont.`return`(res * 2))
.>>=(res => Cont.`return`(res / 3.0))
.runCont(a => println(a))
很棒!这一节,我们介绍了用 Cont 封装计算的方法,在形式上,实现了把连续操作组合成一个更大操作的流式调用风格。但是如果中间操作抛出异常的话,整个过程就中断了,为了捕捉异常,我们需要在语句外加上 try catch 语法,这就属于干脏活了。为了更优雅的处理中间过程发生的异常,下面我们开始介绍 callcc。