Continuations
A long time ago I read a blog post explaining a beautiful use case of call/cc. It solved constraints over a set of variables very elegantly, in a declarative style. I tried doing something similar in Scala with the Cont monad. The result looked like this:
(for (
  a <- choose(1,2,3,4,5);
  b <- choose(1,2,3,4,5);
  c <- choose(1,2,3,4,5)
) yield {
  check((a*a + b*b) == c*c)
  (a,b,c)
}).get
The above code snippet tries to find Pythagorean triples. In fact, we can replace the condition and solve any constraint satisfaction problem that can be solved with backtracking. First of all, let's see what continuations are.
Continuations
Imagine a programming language which doesn't have return statements. What would functions look like?
Every function would take another function as a parameter and pass its result to it. That parameter function is called a continuation. It can be thought of as a sort of callback.
In statically typed languages, a computation of A in continuation-passing style is any function which takes a continuation (A => R) and gives back an R. (Usually R is the bottom type, as with call/cc.)
So let's see a simple example: the sum of squares of two integers, written with continuations.
def add[R](a: Int, b: Int)(cont: Int => R): R = cont(a+b)
def square[R](a: Int)(cont: Int => R): R = cont(a*a)
def sumOfSquares[R](a: Int, b: Int)(cont: Int => R): R = {
  square(a)(asquare =>
    square(b)(bsquare =>
      add(asquare, bsquare)(cont)))
}
The add and square functions look clean and express our intent clearly, but sumOfSquares looks a bit clumsy. Its structure, however, suggests that there could be a Monad instance for it.
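To see how a caller consumes these functions, here is a minimal usage sketch. The wrapper object CpsUsageSketch and the values asInt and asText are names made up for illustration; the three functions are the ones defined above.

```scala
// Hypothetical wrapper object, used only to keep the sketch self-contained.
object CpsUsageSketch {
  def add[R](a: Int, b: Int)(cont: Int => R): R = cont(a + b)
  def square[R](a: Int)(cont: Int => R): R = cont(a * a)

  def sumOfSquares[R](a: Int, b: Int)(cont: Int => R): R =
    square(a)(asquare =>
      square(b)(bsquare =>
        add(asquare, bsquare)(cont)))

  // The caller decides what happens to the result by picking the continuation.
  val asInt: Int = sumOfSquares(3, 4)(n => n)               // continuation hands the Int back unchanged
  val asText: String = sumOfSquares(3, 4)(n => s"sum = $n") // continuation formats it instead
}
```

Note that the result type R is chosen separately at each call site, by whatever the continuation returns.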
Cont Monad
We first abstract Cont as a class, and will then implement map and flatMap on it.
abstract class Cont[A] {
  def run[R](cont: A => R): R
}
def add(a: Int, b: Int): Cont[Int] = new Cont[Int] {
  def run[R](cont: Int => R) = cont(a+b)
}
This makes the code much clearer than our original formulation, (A => R) => R, since the type parameter R of the run method is bound only when you call it. For instance, the following snippet is valid when add returns Cont[Int] instead of (Int => R) => R:
val sumOf1And2 = add(1,2)
val asString = sumOf1And2.run(_.toString)
val asDouble = sumOf1And2.run(_.toDouble)
Now, going back to our Monad instance, let's implement map on it. Its signature looks like this:
abstract class Cont[A] {
  ....
  def map[B](fn: A => B): Cont[B]
}
What map does is transform the result into something else, i.e., we apply the transformation before we call the next continuation. It's like calling a function on the return value of some other function.
What about flatMap? Its signature is this:
abstract class Cont[A] {
  ....
  def flatMap[B](fn: A => Cont[B]): Cont[B]
}
It's exactly like map, except that the transformation function also follows the convention of not returning its result and instead taking a continuation. So our flatMap and map implementations look like this:
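abstract class Cont[A] {
  def run[R](cont: A => R): R
  // Apply fn to the result before handing it to the next continuation.
  def map[B](fn: A => B): Cont[B] = {
    val self = this
    new Cont[B] { def run[R](cont: B => R) = self.run(cont compose fn) }
  }
  // Like map, but fn itself yields a Cont, which is then run with the next continuation.
  def flatMap[B](fn: A => Cont[B]): Cont[B] = {
    val self = this
    new Cont[B] { def run[R](cont: B => R) = self.run(a => fn(a).run(cont)) }
  }
  // Run with the identity continuation to get a plain value back.
  def get = run(identity)
}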
We also have a get function in Cont, which turns a Cont back into a real return value.
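With map and flatMap in place, the clumsy sumOfSquares from earlier can be rewritten as a for-comprehension. The following is a sketch under a few assumptions: the wrapper object ContMonadSketch and the Cont-returning square are added here for illustration, a self alias stands in for the val self = this used above, and get uses an explicit lambda instead of identity.

```scala
// Hypothetical wrapper object, used only to keep the sketch self-contained.
object ContMonadSketch {
  abstract class Cont[A] { self =>
    def run[R](cont: A => R): R
    def map[B](fn: A => B): Cont[B] = new Cont[B] {
      def run[R](cont: B => R): R = self.run(cont compose fn)
    }
    def flatMap[B](fn: A => Cont[B]): Cont[B] = new Cont[B] {
      def run[R](cont: B => R): R = self.run(a => fn(a).run(cont))
    }
    def get: A = run(a => a)
  }

  def add(a: Int, b: Int): Cont[Int] = new Cont[Int] {
    def run[R](cont: Int => R): R = cont(a + b)
  }
  // Assumed counterpart of add; not shown in Cont form in the text above.
  def square(a: Int): Cont[Int] = new Cont[Int] {
    def run[R](cont: Int => R): R = cont(a * a)
  }

  // The compiler desugars this into nested flatMap calls ending in a map.
  def sumOfSquares(a: Int, b: Int): Cont[Int] =
    for {
      asquare <- square(a)
      bsquare <- square(b)
      sum     <- add(asquare, bsquare)
    } yield sum

  val result: Int = sumOfSquares(3, 4).get
}
```

The nesting of callbacks is still there, but the for-comprehension hides it behind what looks like straight-line code.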
Control structures
Why would anyone want to give back a Cont when it's easier to return a value? Because Cont is far more powerful: the run function can choose to call its cont parameter once, multiple times, or never at all. That's exactly what we are going to do with our original Pythagorean triple example.
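To make the "once, multiple times, or never" options concrete, here is a small sketch. The names once, each, abort and the wrapper object ControlSketch are hypothetical helpers invented for this illustration, and Cont is trimmed down to just run. Since run must produce an R for any R, the only way to never call the continuation in this encoding is to throw.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.Try

// Hypothetical wrapper object, used only to keep the sketch self-contained.
object ControlSketch {
  abstract class Cont[A] {
    def run[R](cont: A => R): R
  }

  // Calls the continuation exactly once.
  def once[A](a: A): Cont[A] = new Cont[A] {
    def run[R](cont: A => R): R = cont(a)
  }

  // Calls the continuation for every value and keeps only the last result.
  def each[A](first: A, rest: A*): Cont[A] = new Cont[A] {
    def run[R](cont: A => R): R =
      rest.foldLeft(cont(first))((_, a) => cont(a))
  }

  // Never calls the continuation; it cannot build an R, so it throws instead.
  def abort[A](message: String): Cont[A] = new Cont[A] {
    def run[R](cont: A => R): R = throw new RuntimeException(message)
  }

  // Record every value that reaches the continuation.
  val seen = ListBuffer.empty[Int]
  val last: Int = each(1, 2, 3).run { n => seen += n; n * 10 }
  val aborted: Boolean = Try(abort[Int]("stop").run(n => n)).isFailure
}
```

Here the continuation passed to each runs three times, while the one passed to abort never runs.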
So the choose function, instead of returning a chosen value, gives back a Cont. The run function in that Cont keeps calling the cont parameter with each value until it finds one that does not throw an exception. The check function is even simpler: it throws an exception whenever it gets a false value.
Like this:
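import scala.util.Try

def choose[A](as: A*): Cont[A] = new Cont[A] {
  override def run[R](cont: A => R): R = {
    doChoose(cont, as)
  }
  // Try each value in turn; a thrown exception means "backtrack and try the next one".
  def doChoose[R](cont: A => R, values: Iterable[A]): R = {
    if (values.isEmpty) throw new RuntimeException("No choice is valid")
    Try(cont(values.head)).getOrElse(doChoose(cont, values.tail))
  }
}
// Signals a failed constraint by throwing, which the nearest enclosing choose catches.
def check(bool: => Boolean): Unit = {
  if (!bool) throw new RuntimeException("Constraint fails")
}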
Thus we keep trying each combination of values for a, b and c until we hit upon one which satisfies our check conditions.
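As a sketch of the earlier claim that the condition can be swapped for another constraint, the following self-contained example searches for two numbers whose sum is 7 and whose product is 12. The wrapper object ConstraintSketch, the values pair and unsolvable, and the constraint itself are made up for illustration; Cont, choose and check follow the definitions above, with a self alias in place of val self = this.

```scala
import scala.util.Try

// Hypothetical wrapper object, used only to keep the sketch self-contained.
object ConstraintSketch {
  abstract class Cont[A] { self =>
    def run[R](cont: A => R): R
    def map[B](fn: A => B): Cont[B] = new Cont[B] {
      def run[R](cont: B => R): R = self.run(cont compose fn)
    }
    def flatMap[B](fn: A => Cont[B]): Cont[B] = new Cont[B] {
      def run[R](cont: B => R): R = self.run(a => fn(a).run(cont))
    }
    def get: A = run(a => a)
  }

  def choose[A](as: A*): Cont[A] = new Cont[A] {
    def run[R](cont: A => R): R = doChoose(cont, as)
    // A failure anywhere downstream is caught here and triggers the next candidate.
    private def doChoose[R](cont: A => R, values: Iterable[A]): R = {
      if (values.isEmpty) throw new RuntimeException("No choice is valid")
      Try(cont(values.head)).getOrElse(doChoose(cont, values.tail))
    }
  }

  def check(bool: => Boolean): Unit =
    if (!bool) throw new RuntimeException("Constraint fails")

  // Illustrative constraint: x + y == 7 and x * y == 12.
  val pair: (Int, Int) = (for {
    x <- choose(1, 2, 3, 4, 5, 6)
    y <- choose(1, 2, 3, 4, 5, 6)
  } yield {
    check(x + y == 7 && x * y == 12)
    (x, y)
  }).get

  // When no candidate satisfies the constraint, the search ends with an exception.
  val unsolvable: Boolean = Try((for {
    x <- choose(1, 2, 3)
  } yield { check(x > 10); x }).get).isFailure
}
```

Candidates are tried in the order given, so the search stops at the first satisfying pair, which here is (3, 4). One design caveat of this approach: Try catches any non-fatal exception, so a genuine bug in the body of the comprehension would also be treated as a failed choice.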
The entire code is in here