## Continuations

A long time ago I read a blog post explaining a beautiful use case of `call/cc`. It solved constraints on a set of variables very elegantly, in a declarative style. I tried doing something similar in Scala with the `Cont` monad. The result looked like this:

```scala
(for (
  a <- choose(1, 2, 3, 4, 5);
  b <- choose(1, 2, 3, 4, 5);
  c <- choose(1, 2, 3, 4, 5)
) yield {
  check((a*a + b*b) == c*c)
  (a, b, c)
}).get
```

The snippet above searches for a Pythagorean triple: values `a`, `b` and `c` between 1 and 5 with `a*a + b*b == c*c`. In fact, by replacing the condition we can solve any constraint satisfaction problem that can be solved with backtracking. First of all, let's see what continuations are.

## Continuations

Imagine a programming language that doesn't have return statements. What would functions look like?

Every function would take an extra function as a parameter and pass its result to it instead of returning it. That parameter function is called a continuation. You can think of it as a kind of callback.

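In a statically typed language, a computation that produces an `A` in this style takes a continuation of type `A => R` and gives back an `R`; that is, it has type `(A => R) => R`. (For continuations captured with `call/cc`, which never return to their caller, `R` is effectively the bottom type.)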
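To make the shape `(A => R) => R` concrete, here is a minimal sketch using a hypothetical type alias `CPS` (not part of the original code). Instead of returning `42`, `fortyTwo` hands `42` to whatever continuation it is given, and the caller picks `R`.

```scala
object CpsShape {
  // Hypothetical alias: a computation producing an A, in continuation-passing style.
  type CPS[A, R] = (A => R) => R

  // Instead of returning 42, pass it to the continuation k.
  def fortyTwo[R]: CPS[Int, R] = k => k(42)

  def main(args: Array[String]): Unit = {
    println(fortyTwo[String](n => s"answer: $n")) // answer: 42
    println(fortyTwo[Int](_ + 1))                 // 43
  }
}
```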

So let's look at a simple example: the sum of squares of two integers, written with continuations.

```scala
def add[R](a: Int, b: Int)(cont: Int => R): R = cont(a + b)
def square[R](a: Int)(cont: Int => R): R = cont(a * a)

// Each step passes its result on to the next step's continuation.
def sumOfSquares[R](a: Int, b: Int)(cont: Int => R): R = {
  square(a)(asquare =>
    square(b)(bsquare =>
      add(asquare, bsquare)(cont)))
}
```
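As a quick usage sketch, the same `sumOfSquares` computation can finish with different continuations. The `object` wrapper, `main`, and the explicit type arguments are additions made only so the example runs on its own.

```scala
object SumOfSquaresCps {
  def add[R](a: Int, b: Int)(cont: Int => R): R = cont(a + b)
  def square[R](a: Int)(cont: Int => R): R = cont(a * a)

  // Explicit [R] on the inner calls keeps type inference straightforward.
  def sumOfSquares[R](a: Int, b: Int)(cont: Int => R): R =
    square[R](a)(asquare =>
      square[R](b)(bsquare =>
        add[R](asquare, bsquare)(cont)))

  def main(args: Array[String]): Unit = {
    println(sumOfSquares[Int](3, 4)(identity))            // 25
    println(sumOfSquares[String](3, 4)(x => s"sum = $x")) // sum = 25
  }
}
```

The caller, not `sumOfSquares`, decides what the final answer looks like.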

The `add` and `square` functions look clean and express our intent clearly, but `sumOfSquares` looks a bit clumsy: every step is nested inside the continuation of the previous one. That nesting structure suggests there could be a `Monad` instance for it.

## Cont Monad

We first abstract `Cont` as a class, then implement `map` and `flatMap` on it.

```scala
abstract class Cont[A] {
  def run[R](cont: A => R): R
}

def add(a: Int, b: Int): Cont[Int] = new Cont[Int] {
  def run[R](cont: Int => R): R = cont(a + b)
}
```

This is cleaner than our original encoding `(A => R) => R`, because the type parameter `R` of the `run` method is bound only when you call it, so a single `Cont[Int]` value can be run with continuations of different result types. For instance, the following snippet is valid when `add` returns a `Cont[Int]` instead of an `(Int => R) => R`:

```scala
val sumOf1And2 = add(1, 2)
val asString = sumOf1And2.run(_.toString) // R = String
val asDouble = sumOf1And2.run(_.toDouble) // R = Double
```

Now back to our Monad instance. Let's implement `map` first; its signature looks like this:

```scala
abstract class Cont[A] {
  // ...
  def map[B](fn: A => B): Cont[B]
}
```

`map` transforms the result into something else: we apply the transformation before passing the value on to the next continuation. It's like calling a function on the return value of another function.
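The same idea can be written directly against the raw `(A => R) => R` encoding. `mapCps` below is a hypothetical helper for illustration, not part of the `Cont` class: it applies `fn` first and then hands the transformed value to the next continuation.

```scala
object MapCps {
  // map on the raw CPS encoding: transform the value, then continue with it.
  def mapCps[A, B, R](c: (A => R) => R)(fn: A => B): (B => R) => R =
    cont => c(a => cont(fn(a)))

  def main(args: Array[String]): Unit = {
    val three: (Int => String) => String = k => k(3)
    val doubled: (Int => String) => String = mapCps(three)((x: Int) => x * 2)
    println(doubled(n => s"result: $n")) // result: 6
  }
}
```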

What about `flatMap`? Its signature is:

```scala
abstract class Cont[A] {
  // ...
  def flatMap[B](fn: A => Cont[B]): Cont[B]
}
```

It's just like `map`, except that the transformation function itself follows the convention: instead of returning its result directly, it returns a `Cont[B]`, which we then run with the next continuation. Putting it together, our `map` and `flatMap` implementations look like this:

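```scala
abstract class Cont[A] {
  def run[R](cont: A => R): R

  // Apply fn to the value, then pass the result to the next continuation.
  def map[B](fn: A => B): Cont[B] = {
    val self = this
    new Cont[B] { def run[R](cont: B => R): R = self.run(cont compose fn) }
  }

  // fn produces another Cont; run it with the next continuation.
  def flatMap[B](fn: A => Cont[B]): Cont[B] = {
    val self = this
    new Cont[B] { def run[R](cont: B => R): R = self.run(a => fn(a).run(cont)) }
  }

  // Run with the identity continuation to get a plain value back.
  def get: A = run(identity)
}
```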

We also added a `get` method to `Cont`, which turns a `Cont` back into a plain return value by running it with the identity continuation.
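With `map` and `flatMap` in place, the clumsy `sumOfSquares` from earlier can be written as a `for` comprehension. This is a minimal sketch that assumes `square` is redefined to return `Cont[Int]` just like `add`; that `square` and the `object` wrapper are additions for illustration.

```scala
object ContSumOfSquares {
  abstract class Cont[A] {
    def run[R](cont: A => R): R
    def map[B](fn: A => B): Cont[B] = {
      val self = this
      new Cont[B] { def run[R](cont: B => R): R = self.run(cont compose fn) }
    }
    def flatMap[B](fn: A => Cont[B]): Cont[B] = {
      val self = this
      new Cont[B] { def run[R](cont: B => R): R = self.run(a => fn(a).run(cont)) }
    }
    def get: A = run(identity)
  }

  def add(a: Int, b: Int): Cont[Int] = new Cont[Int] {
    def run[R](cont: Int => R): R = cont(a + b)
  }
  def square(a: Int): Cont[Int] = new Cont[Int] {
    def run[R](cont: Int => R): R = cont(a * a)
  }

  // The nesting of the hand-written CPS version is now hidden behind flatMap.
  def sumOfSquares(a: Int, b: Int): Cont[Int] =
    for {
      asquare <- square(a)
      bsquare <- square(b)
      sum     <- add(asquare, bsquare)
    } yield sum

  def main(args: Array[String]): Unit = {
    println(sumOfSquares(3, 4).get)                   // 25
    println(sumOfSquares(3, 4).run(_.toString + "!")) // 25!
  }
}
```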

## Control structures

Why would anyone return a `Cont` when it's easier to return a plain value? Because `Cont` is much more powerful: inside `run` you can choose to call the `cont` parameter once, multiple times, or never at all. That is exactly what we will exploit in our original Pythagorean triple example.
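Before building `choose`, here is a small sketch of that freedom on the raw encoding, with the answer type fixed to `List[R]` so that every call's answer can be collected. The names `never`, `once` and `many` are hypothetical and not part of the post's code.

```scala
object ContinuationCalls {
  // Discards the continuation entirely: no answer is produced.
  def never[A, R](k: A => List[R]): List[R] = Nil
  // Calls the continuation exactly once.
  def once[A, R](a: A)(k: A => List[R]): List[R] = k(a)
  // Calls the continuation once per value and concatenates the answers.
  def many[A, R](as: A*)(k: A => List[R]): List[R] = as.toList.flatMap(k)

  def main(args: Array[String]): Unit = {
    val describe: Int => List[String] = n => List(s"got $n")
    println(never[Int, String](describe))         // List()
    println(once[Int, String](1)(describe))       // List(got 1)
    println(many[Int, String](1, 2, 3)(describe)) // List(got 1, got 2, got 3)
  }
}
```

With the polymorphic `Cont[A]` used in this post, `R` is unknown inside `run`, so the only way to "never" call the continuation is to throw, which is exactly how `check` rejects a combination.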

So instead of returning a chosen value, the `choose` function gives back a `Cont`. The `run` method of that `Cont` calls the `cont` parameter with each value in turn until one of those calls completes without throwing an exception. The `check` function is even simpler: it throws an exception every time it gets a false value.

Like this:

```scala
import scala.util.Try

def choose[A](as: A*): Cont[A] = new Cont[A] {
  override def run[R](cont: A => R): R = {
    doChoose(cont, as)
  }
  // Try each value in turn; if the rest of the computation throws,
  // backtrack and try the next value.
  def doChoose[R](cont: A => R, values: Iterable[A]): R = {
    if (values.isEmpty) throw new RuntimeException("No choice is valid")
    Try(cont(values.head)).getOrElse(doChoose(cont, values.tail))
  }
}

// Reject the current combination by throwing.
def check(bool: => Boolean): Unit = {
  if (!bool) throw new RuntimeException("Constraint fails")
}
```
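Putting the pieces together, a self-contained sketch of the whole program might look like the following. The `object` wrapper, `main`, and the second search `factorPair` (a hypothetical puzzle: two digits with product 12 and sum 7) are additions; the point of `factorPair` is that only the `check` condition changes.

```scala
import scala.util.Try

object PythagoreanSearch {
  abstract class Cont[A] {
    def run[R](cont: A => R): R
    def map[B](fn: A => B): Cont[B] = {
      val self = this
      new Cont[B] { def run[R](cont: B => R): R = self.run(cont compose fn) }
    }
    def flatMap[B](fn: A => Cont[B]): Cont[B] = {
      val self = this
      new Cont[B] { def run[R](cont: B => R): R = self.run(a => fn(a).run(cont)) }
    }
    def get: A = run(identity)
  }

  // Calls cont with each value in turn; a throwing call means "try the next one".
  def choose[A](as: A*): Cont[A] = new Cont[A] {
    override def run[R](cont: A => R): R = doChoose(cont, as)
    def doChoose[R](cont: A => R, values: Iterable[A]): R = {
      if (values.isEmpty) throw new RuntimeException("No choice is valid")
      Try(cont(values.head)).getOrElse(doChoose(cont, values.tail))
    }
  }

  // Rejects the current combination by throwing.
  def check(bool: => Boolean): Unit =
    if (!bool) throw new RuntimeException("Constraint fails")

  def triple: (Int, Int, Int) =
    (for {
      a <- choose(1, 2, 3, 4, 5)
      b <- choose(1, 2, 3, 4, 5)
      c <- choose(1, 2, 3, 4, 5)
    } yield {
      check(a * a + b * b == c * c)
      (a, b, c)
    }).get

  // Same machinery, different constraint.
  def factorPair: (Int, Int) =
    (for {
      x <- choose(1, 2, 3, 4, 5, 6, 7, 8, 9)
      y <- choose(1, 2, 3, 4, 5, 6, 7, 8, 9)
    } yield {
      check(x * y == 12 && x + y == 7)
      (x, y)
    }).get

  def main(args: Array[String]): Unit = {
    println(triple)     // (3,4,5)
    println(factorPair) // (3,4)
  }
}
```

Because `doChoose` wraps each attempt in `Try`, a failure deep inside the nested continuations unwinds to the nearest `choose` that still has values left.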

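Thus we keep trying combinations of values for `a`, `b` and `c` until we hit upon one that satisfies our check condition. When an inner `choose` runs out of values it throws, and the enclosing `choose` catches that and moves on to its own next value, which is exactly backtracking.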

The entire code is in here