[Kotlin Pearls 8] Recursion, Tail Recursion and Y Combinator

Understanding how the tailrec keyword works and how to take advantage of recursion in Kotlin

Uberto Barbini
ProAndroidDev

Image credits: https://pixabay.com/illustrations/fractal-abstract-red-white-stripes-542155/

A year has already passed since I wrote the first Kotlin Pearl blog post. Thanks to everybody (and you really are a lot) who supported me by clapping for or liking my posts.

This post will cover:

  • How to solve a problem using recursion
  • How to transform Head Recursion into Tail Recursion
  • The tailrec keyword and how it’s translated into ByteCode
  • How to implement the Y-Combinator in Kotlin

Recursion is a powerful technique: it allows us to define a complex problem in terms of a simpler version of the same problem.

The main characteristic of a recursive function is that at some point it calls itself. It can be quite difficult to grasp at the beginning, but as the saying goes:

To understand recursion,

you must first understand recursion

Let’s see a simple example of the power of recursion. You may not know how to sum 100 numbers, but you certainly know how to sum 2 numbers.

What about 3 numbers? Well, if you already know the sum of 2 of them, you can add it to the third number. So now you know how to sum 3 numbers.

What about 4 numbers? Well, since you already know how to sum 3 numbers… I think you can see where we are going.

So the recursive algorithm for this problem, in pseudo-code, is:

if the size of the list of numbers == 2
    return the sum of the two numbers
else
    return the sum of the first number and (the sum of all the others)

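and in Kotlin it looks like this:

// assumes the list contains at least 2 numbers
fun sumNumbers(nums: List<Int>): Int =
    if (nums.size == 2)
        nums[0] + nums[1]                   // base case: just add the two numbers
    else
        nums[0] + sumNumbers(nums.drop(1))  // first number + sum of all the others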
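The function above assumes at least two numbers: with a single number it would end up reading nums[0] of an empty list. A minimal sketch of a variant, assuming we are happy to use the empty list as the base case (sumNumbersFrom is a hypothetical name), works for lists of any size:

```kotlin
// Hypothetical variant of sumNumbers: the empty list is the base case,
// so lists of any size (including 0 or 1 elements) are supported.
fun sumNumbersFrom(nums: List<Int>): Int =
    if (nums.isEmpty())
        0                                            // the sum of no numbers is 0
    else
        nums.first() + sumNumbersFrom(nums.drop(1))  // first number + sum of all the others
```

The structure is the same as before; only the base case moved down to the simplest possible input.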

Now something a little more difficult: Euclid’s Algorithm for the Greatest Common Divisor. Given two positive numbers, their GCD is the largest number that divides them both. To find it, we can repeatedly subtract the smaller number from the bigger one until the two become equal (in the worst case, both equal to 1).

In pseudo-code the algorithm is like this:

if a = b: gcd(a, b) = a
if a > b: gcd(a, b) = gcd(a - b, b)
if b > a: gcd(a, b) = gcd(a, b - a)

To see how it works, let’s see how to find the GCD of 30 and 20 with this algorithm. These are the steps:

gcd(30, 20) a > b 
-> gcd(10, 20) a < b
-> gcd(10, 10) a = b
-> 10
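To reproduce these steps mechanically, here is a small sketch of a hypothetical helper, gcdSteps, that follows the same three rules but returns every (a, b) pair it visits instead of just the result. It assumes both numbers are positive:

```kotlin
// Hypothetical helper: records each (a, b) pair visited by the
// subtraction-based Euclid algorithm. Assumes a > 0 and b > 0,
// otherwise the recursion never reaches the a == b case.
fun gcdSteps(a: Int, b: Int): List<Pair<Int, Int>> =
    when {
        a == b -> listOf(a to b)                       // a = b: the GCD is a, stop
        a > b -> listOf(a to b) + gcdSteps(a - b, b)   // a > b: subtract b from a
        else -> listOf(a to b) + gcdSteps(a, b - a)    // b > a: subtract a from b
    }
```

For gcdSteps(30, 20) the pairs are (30, 20), (10, 20) and (10, 10), matching the trace above; the last pair holds the GCD.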

For a complete explanation see the article on Wikipedia.

How can we implement this in Kotlin?

The non-recursive way is to use a while loop with two variables:

fun gcdNR(n1: Int, n2: Int): Int {
    var a = n1
    var b = n2
    while (a != b) {
        if (a > b) a = a - b
        else b = b - a
    }
    return a
}

We can compare it with the recursive version:

fun gcd(n1: Int, n2: Int): Int =
    when {
        n1 == n2 -> n1
        n1 > n2 -> gcd(n1 - n2, n2)
        else -> gcd(n1, n2 - n1)
    }

You can see how the recursive version is not only shorter but also much easier to read.

If you don’t know the algorithm, to understand how the loop version works you need to “wind” the loop in your mind. In the recursive version, on the contrary, the intent of the code is very clear.

When we look at recursive algorithms, a useful distinction is between Head Recursion and Tail Recursion.

In Head Recursion, we call ourselves first and then do something with the result of the recursion.

In Tail Recursion, the recursion is the last operation in all logical branches of the function.
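As a small illustration, here is the factorial written both ways (factHead and factTail are hypothetical names). The tail-recursive version needs an extra accumulator parameter to carry the partial result:

```kotlin
// Head recursion: the recursive call happens first, then its result
// is multiplied by n, so the multiplication is the last operation.
fun factHead(n: Long): Long =
    if (n <= 1) 1 else n * factHead(n - 1)

// Tail recursion: the partial result travels in the accumulator `acc`,
// so the recursive call is the very last operation in every branch.
fun factTail(n: Long, acc: Long = 1): Long =
    if (n <= 1) acc else factTail(n - 1, acc * n)
```

Both return 120 for 5, but factTail has nothing left to do after its recursive call returns.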

To see the difference, let’s write a Fibonacci number generator.

First the non-recursive version:

fun fibonacciNR(n: Int): Int {
    // before iteration i: prev1 = F(i - 1), prev2 = F(i - 2)
    var prev1 = 1
    var prev2 = 1
    var new = 1 // F(1) = F(2) = 1
    for (i in 3..n) {
        new = prev1 + prev2
        prev2 = prev1
        prev1 = new
    }
    return new
}

Now the recursive version:

fun fibonacciR(n: Int): Int = when (n) {
    1 -> 1
    2 -> 1
    else -> fibonacciR(n - 1) + fibonacciR(n - 2)
}

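The code is much shorter, but it computes the same values again and again: fibonacciR(n - 1) itself calls fibonacciR(n - 2), which is also computed separately, so the number of calls grows exponentially with n.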
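To make the waste visible, here is a sketch of a hypothetical instrumented copy, fibonacciWithCalls, that follows exactly the same recursion as fibonacciR but also returns how many calls were made:

```kotlin
// Returns the n-th Fibonacci number together with the number of calls
// performed, using the same double recursion as fibonacciR.
fun fibonacciWithCalls(n: Int): Pair<Int, Int> =
    when (n) {
        1, 2 -> 1 to 1 // base case: value 1, one call
        else -> {
            val (a, callsA) = fibonacciWithCalls(n - 1)
            val (b, callsB) = fibonacciWithCalls(n - 2)
            (a + b) to (1 + callsA + callsB) // this call plus all the nested ones
        }
    }
```

For n = 10 it performs 109 calls to compute 55; in general the count is 2 * F(n) - 1, so it grows as fast as the Fibonacci numbers themselves.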

You can see that this is using Head Recursion, because the last operation is the sum of the two recursive results, not a recursive call.

But we can transform it into a tail-recursive version:

fun fibonacci(n: Int): Int = fibTail(n, 0, 1)

// prev2 and prev1 hold two consecutive Fibonacci numbers
fun fibTail(n: Int, prev2: Int, prev1: Int): Int =
    when (n) {
        0 -> prev2
        1 -> prev1
        else -> fibTail(n - 1, prev1, prev1 + prev2)
    }

In this way, we are also sure we calculate each number only once.

The tail recursion optimization

In general, you can transform a Head Recursive function into a Tail Recursive one by adding one or more parameters that carry the partial result (accumulators).

Apart from the special case of the Fibonacci double call, you may wonder: what is the benefit of such a transformation?

The main problem is that each time we call a function, even the same one, a new frame must be pushed onto the current thread’s stack.

This consumes memory and is also a relatively slow operation. If you recurse more than a few thousand times, you will get the infamous StackOverflowError.
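A minimal sketch to see the problem, assuming a JVM with its default thread stack size (the exact depth at which the stack overflows depends on the -Xss setting and on the frame size). The names sumToRec and overflows are hypothetical:

```kotlin
// Plain recursion: every call pushes a new frame onto the thread stack
// and keeps it there until the addition after the call can be performed.
fun sumToRec(n: Long): Long =
    if (n == 0L) 0L else n + sumToRec(n - 1)

// Hypothetical helper: true if computing sumToRec(n) exhausts the stack.
fun overflows(n: Long): Boolean =
    try {
        sumToRec(n)
        false
    } catch (e: StackOverflowError) {
        true
    }
```

A depth of 1,000 works fine, while 10,000,000 nested calls end in a StackOverflowError. Raising the stack size with -Xss only moves the limit; it does not remove it.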

This performance hit has kept people away from recursion in serious code bases.

Even Grady Booch said we cannot use recursion for any serious problem. End of the story?

Not so fast!

It has been demonstrated that any tail-recursive algorithm can be transformed into a loop. Loops are hard to read, but they are fast to execute and they don’t grow the stack.

Wouldn’t it be awesome if we could write the code in a nice recursive style and have the compiler transform it for us into fast “loopy” code?

Wait a minute! Yes, we can!

In Kotlin, there is a keyword that instructs the compiler to do exactly this. Unimaginatively enough, the keyword is called tailrec.

Let’s try to add this keyword to our examples:

// adds x + (x - 1) + ... + 2 to acc: call sumAll(1, n) to get 1 + 2 + ... + n
tailrec fun sumAll(acc: Int, x: Int): Int = when {
    x <= 1 -> acc
    else -> sumAll(acc + x, x - 1)
}

The compiler automatically generates the ByteCode corresponding to this Java code:

public static final int sumAll(int acc, int x) {
    while (x > 1) {
        int var10000 = acc + x;
        --x;
        acc = var10000;
    }

    return acc;
}

So what the Kotlin compiler did is translate the recursive call into a while loop.

The same thing happens if we add tailrec to the gcd function. The generated ByteCode, decompiled to Java, looks like this:

public static final int gcd(int n1, int n2) {
    while (n1 != n2) {
        if (n1 > n2) {
            n1 -= n2;
        } else {
            n2 -= n1;
        }
    }
    return n1;
}

Note that in Java we can modify function parameters directly, while in Kotlin, fortunately, we cannot.
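The Kotlin source behind that decompiled code is not listed; a minimal sketch, assuming it is simply the gcd function from above with the tailrec modifier added:

```kotlin
// The recursive gcd with the tailrec modifier: both recursive calls are
// the last operation of their branch, so the compiler can emit a loop.
// Assumes n1 and n2 are positive.
tailrec fun gcd(n1: Int, n2: Int): Int =
    when {
        n1 == n2 -> n1
        n1 > n2 -> gcd(n1 - n2, n2)
        else -> gcd(n1, n2 - n1)
    }
```

gcd(1_000_000, 1) needs 999,999 subtraction steps. With tailrec they run as iterations of a loop; without the modifier, that many nested calls would likely overflow the stack on a default JVM.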

Finally, let’s try it with the tailrec Fibonacci:

public static final int fibonacci(int n) {
    return fibTail(n, 0, 1);
}

public static final int fibTail(int n, int prev2, int prev1) {
    while (true) {
        int var10000;
        switch (n) {
            case 0:
                var10000 = prev2;
                break;
            case 1:
                var10000 = prev1;
                break;
            default:
                var10000 = n - 1;
                int var10001 = prev1;
                prev1 += prev2;
                prev2 = var10001;
                n = var10000;
                continue;
        }

        return var10000;
    }
}

Which is another loop…
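For reference, a sketch of a Kotlin source that would produce this decompiled code, assuming it is the tail-recursive Fibonacci from earlier with tailrec added to fibTail; the case 0 branch in the Java suggests the source also handles n == 0:

```kotlin
// Tail-recursive Fibonacci: prev2 and prev1 are two consecutive Fibonacci
// numbers, and each step shifts the pair one position forward.
fun fibonacci(n: Int): Int = fibTail(n, 0, 1)

tailrec fun fibTail(n: Int, prev2: Int, prev1: Int): Int =
    when (n) {
        0 -> prev2
        1 -> prev1
        else -> fibTail(n - 1, prev1, prev1 + prev2)
    }
```

Since the result is an Int, it overflows for n greater than 46 (F(46) = 1,836,311,903 is the last value that fits); switching to Long extends the range.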

So tailrec gives us the opportunity to write code in a cleaner recursive way and have it compiled into a fast while loop.

Next time someone tells you that recursion has no place outside PhD theses because of performance, you can tell him: “With all due respect, Sir, this is codswallop! I use Kotlin.”

After years of Java, I’m happy to be using elegant recursive solutions in my production code again, thanks to the Kotlin compiler.

I strongly encourage you to explore the possibilities of recursive algorithms with tailrec optimization.

The Y Combinator

Finally, let's have a bit of fun exploring an interesting piece of functional programming theory regarding recursion.

We want to find the fixed point of a recursive function.

But first, what is the fixed point of a function?

In mathematics, a fixed point of a function is an element of the function’s domain that is mapped to itself by the function. (Wikipedia)

Let’s say that you have a function that returns the same type as the input parameter.

For example:

fun myFormula(x: Double): Double = x * 5 - 8
fun square(x: Int): Int = x * x
fun reverse(s: String): String = s.reversed()

Those functions are called endomorphisms.

The fixed point of an endomorphism is the input value (or values) that generate an identical result. In our examples:

myFormula(2.0) == 2.0
square(1) == 1
reverse("abba") == "abba"
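The definition translates directly into code. A small sketch, where isFixedPoint is a hypothetical helper and the three endomorphisms are the ones defined above:

```kotlin
fun myFormula(x: Double): Double = x * 5 - 8
fun square(x: Int): Int = x * x
fun reverse(s: String): String = s.reversed()

// A value x is a fixed point of the endomorphism f when f maps x to itself.
fun <A> isFixedPoint(f: (A) -> A, x: A): Boolean = f(x) == x
```

isFixedPoint(::myFormula, 2.0) and isFixedPoint(::square, 1) are true, while isFixedPoint(::square, 3) is false because 9 is not 3.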

Now, back to recursion: what is the fixed point of a function that takes another function as its input parameter?

In other words, what is the fixed point of functions with types like these?

fun recursiveLong(f: (Long) -> Long): (Long) -> Long
fun recursiveString(f: (String) -> String): (String) -> String

In lambda calculus, the best-known fixed-point combinator, the one used to define recursive functions, is called the Y Combinator.
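For reference, the classic untyped lambda-calculus definition, and the β-reduction steps showing why Y g is a fixed point of g:

```latex
Y = \lambda f.\,(\lambda x.\, f\,(x\,x))\,(\lambda x.\, f\,(x\,x))

Y\,g \;\to_\beta\; (\lambda x.\, g\,(x\,x))\,(\lambda x.\, g\,(x\,x))
     \;\to_\beta\; g\,\big((\lambda x.\, g\,(x\,x))\,(\lambda x.\, g\,(x\,x))\big)
     \;=_\beta\; g\,(Y\,g)
```

Under eager evaluation, as in Kotlin, the self-application keeps unfolding g(g(g(...))) before g is ever called, which is why a direct translation needs some laziness to terminate.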

It is such a cool concept that the name has been adopted by a famous startup accelerator; I suppose each seed is some kind of recursion…

Is it possible to write a Y Combinator in Kotlin? Let’s try together!

What we want to do is:

  1. define a function f that takes another function as input and returns a function of the same type
  2. call f, passing itself as the parameter

Let’s start by defining a typealias for a function whose input and return types are the same.

// Endomorphism: a function that returns the same type as its input
typealias EndoMorph<A> = (A) -> A

We can define some recursive functions as EndoMorphisms:

fun fac(f: EndoMorph<Long>): EndoMorph<Long> =
    { x -> if (x <= 1) 1 else x * f(x - 1) }

fun fib(f: EndoMorph<Int>): EndoMorph<Int> =
    { x -> if (x <= 2) 1 else f(x - 1) + f(x - 2) }

fun reverse(f: EndoMorph<String>): EndoMorph<String> =
    { s -> if (s.isEmpty()) "" else s.last() + f(s.dropLast(1)) }

Note how we turned a recursive function into a function that, strictly speaking, is not recursive: it uses another function passed as a parameter. It just happens that the other function is the function itself!

I find the whole affair mind warping but so intriguing…

Then, to find the fixed point, we can call f passing the fixed point itself, fix(f), as the parameter. The simplest implementation would be like this:


fun <A> fix(f: (EndoMorph<A>) -> EndoMorph<A>): EndoMorph<A> =
    f(fix(f))

This compiles, but… if we try to use it with one of the previously defined functions:

fix(::fib)(10)
// Exception in thread "main" java.lang.StackOverflowError

The problem is that the input parameter must be evaluated only when needed. In other words, we need to make it lazy:

fun <A> lazyFix(f: (() -> EndoMorph<A>) -> EndoMorph<A>): EndoMorph<A> =
    f({ lazyFix(f) })

This works, but it requires a different signature for the recursive function:

fun lazyFib(f: () -> EndoMorph<Int>): EndoMorph<Int> =
    { x -> if (x <= 2) 1 else f()(x - 1) + f()(x - 2) }

lazyFix(::lazyFib)(10) // 55

Can we do better?

One possibility is to use a data class to lazily call the function when needed.

data class LazyFix<A>(val callIt: (LazyFix<A>) -> EndoMorph<A>)

fun <A> fix(lazyFix: LazyFix<A>): EndoMorph<A> =
    lazyFix.callIt(lazyFix)

fun <A> yCombinator(recursiveFun: (EndoMorph<A>) -> EndoMorph<A>): EndoMorph<A> =
    fix(LazyFix { rec -> recursiveFun { x -> fix(rec)(x) } })

Let’s try to use it:

yCombinator(::fac)(19)              // 121645100408832000
yCombinator(::fib)(17)              // 1597
yCombinator(::reverse)("Recursion") // noisruceR

As you can see and test yourself, this is actually working. We made a recursive function out of a nonrecursive one!

Is there any practical application for this? None that I am aware of; still, it is quite an interesting and challenging problem. Let me know if you find it useful for real coding.

The full code for these examples and more is on my GitHub project.

I hope you liked this post; if so, please applaud it and follow me on Twitter and Medium.

JVM and Kotlin independent consultant. Passionate about Code Quality and Functional Programming. Author, public speaker and OpenSource contributor.