Recursion and Chasing your Own Tail

Atul S. Khot

December 2015

In this article by Atul S. Khot, the author of the book Scala Functional Programming Patterns, we will look more closely at recursion. We will see how it helps us write succinct code and how going recursive promotes immutability. We will first look at recursive structures: a structure is recursive if the shape of the whole recurs in the shape of the parts. We will then look at how Scala's pattern matching helps us work on the composing parts. Next, we will look at a problem that arises with very large structures and the mechanism for dealing with it, namely tail call optimization (TCO) and the @tailrec annotation. Finally, we will get a handle on persistent data structures and structural sharing, which avoid needless copying while preserving immutability.

Recursive structures

The find command on Linux (and the dir /s command on Windows) recursively descends into a directory. If there are subdirectories within it, the command descends into each subdirectory, one by one. If those subdirectories, in turn, have subdirectories, the command goes into each one and repeats the process until all the directories are traversed. Let's have a look at the following directory:


Figure 3.1: A directory tree is a recursive structure

Given this directory, try the following command:

 % find ./tmp -type f -exec wc -c {} \;

The find command starts at the tmp directory and applies the wc command to each regular file (-type f), skipping directories.

The command enters tmp and finds a and c. As these are directories, the flow enters a first and finds b and one.txt. As directory b is empty, it moves on to one.txt, for which the -type f predicate is true. So, the characters in one.txt are counted; the flow then returns from a to tmp and recurs into c. The process continues until every node in the directory tree is visited, and every node is visited once and only once. Now, if you look carefully, when we come to node a, we have the same problem to solve as when we started with tmp. This problem is inherently recursive. This is the essence of recursion: we keep reducing the problem by dividing the dataset into smaller and smaller pieces. At some point, though, we need to actually solve the problem (in our case, counting the characters in regular files). In our case, the points where there are no more directories to recur into form the base cases:

  • A subdirectory is empty: the flow just returns in this case.
  • The node is not a directory at all but a regular file (one.txt): we perform the operation (count characters) and return.

Base cases are sub-problems that are solved directly, without dividing them any further. Such base cases allow the algorithm to terminate eventually.
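To make the base cases concrete, here is a small Scala sketch that models the directory tree as a recursive data type. It is illustrative only: FsNode, File, Dir, and CharCount are hypothetical names, and each file carries a precomputed character count instead of being read from disk.

```scala
// Hypothetical model of a directory tree: a Dir holds children of the same type,
// so the shape of the whole recurs in the shape of the parts.
sealed trait FsNode
final case class File(name: String, chars: Int) extends FsNode
final case class Dir(name: String, children: List[FsNode]) extends FsNode

object CharCount {
  // Total number of characters in all regular files under `node`.
  def total(node: FsNode): Int = node match {
    case File(_, chars)   => chars                   // base case: a regular file
    case Dir(_, Nil)      => 0                       // base case: an empty directory
    case Dir(_, children) => children.map(total).sum // recursive case: same problem, smaller pieces
  }
}
```

The empty-directory case is technically redundant (mapping and summing over Nil already gives 0), but it mirrors the two base cases listed above.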

In the previous chapter, we looked at the binary tree traversal method. When the traversal flow hits the Null object, it terminates the traversal. This is a base case. Similarly, when the insertion flow hits the Null object, it adds the new node, forming another base case. When the traversal hits an intermediate node, we have a recursive case.

Pattern matching

Slice and dice is defined as the process of breaking something (for example, information) down into smaller parts to examine and understand it. You can get more information about slice and dice at http://dictionary.reference.com/browse/slice-and-dice.

Let's see this technique in action. We will count the number of elements in a List. There is already a length method defined on List:

scala> List(1,2,3).length
res0: Int = 3

Rolling our own version teaches us something:

object Count extends App {
def count(list: List[Int]): Int = list match {
  case Nil => 0  // 1
  case head :: tail => 1 + count(tail) // 2
}
val l = List(1,2,3,4,5)
println(count(l)) // prints 5
}

The preceding code counts the number of elements in the list. We match the list against two possible cases, which are as follows:

  • The base case: the list is Nil. An empty list matches Nil; as it has zero elements, we return 0.
  • The general case: the list has one or more elements, that is, a first element (head) followed by the rest of the list (tail). The tail may hold more elements or may be empty; we don't know yet. So, we count the head and call the same method recursively with the tail. Here is the process shown pictorially:


Figure 3.2: Losing head at every iteration

Note how we've set head aside in the preceding figure. We have taken the head value into account, in this case by incrementing the count by 1. When we call the method recursively, we have one less element to process. Dropping the head at each step eventually lands us in case 1, letting the processing terminate. We are still iterating over the list and visiting each node, albeit in a different way. Here, we don't use any mutation: there are no variables (no var keyword) used as loop counters. Recursion promotes immutability. Immutability is a boon when we write concurrent code, as we will see in the following chapters.
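The same head and tail deconstruction works for other operations over a list. As a minimal sketch (Sum is a hypothetical name, not from the book), here is a recursive sum with exactly the shape of count: a base case for Nil and a general case that uses head and recurs on tail, again without any var.

```scala
object Sum {
  // Same shape as count: one base case and one recursive case.
  def sum(list: List[Int]): Int = list match {
    case Nil          => 0                // base case: an empty list sums to 0
    case head :: tail => head + sum(tail) // general case: use head, recur on the smaller tail
  }
}
```

Like count, this version keeps a pending "head + ..." for every element, so it is not tail recursive.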

Deconstruction with case statements

Here is how we destructure a list to get at its first element. The following case clause splits the list into the first element (the head) and the rest of the list (the tail):

case head :: tail => 1 + count(tail)

This case matches any list with at least one element. Open the REPL and type the following code:

scala> val head :: tail = List(1, 2, 3, 4) // 1
head: Int = 1
tail: List[Int] = List(2, 3, 4)
scala> val (x,y) = (7, 10) // 2
x: Int = 7
y: Int = 10

Salient points:

  • At 1, we deconstruct a list into its head, 1, and its tail, List(2, 3, 4).
  • At 2, we deconstruct a pair into its constituent values, assigning 7 to x and 10 to y.

Here is a case clause that uses an underscore:

scala> List(1, 2, 3, 4) match {
     |   case head :: tail => println(head)
     |   case _ => println("Nothing") // 3
     | }

The last case clause uses just an underscore, which acts as an unnamed variable: a placeholder that matches anything.

The preceding command will print 1.

The :: symbol is a List extractor (refer to http://www.artima.com/pins1ed/extractors.html for more information on extractors). The x :: y pattern results in a call to the unapply method of the :: object. Yes, you read that right: :: is a case class, and its companion object is also named ::. Just as the Scala compiler calls the apply method we saw earlier when X(...) is used as an expression, it calls unapply when X(...) appears in a pattern matching expression. When we define a case class, we get both the apply and unapply methods written for us:

scala> case class MyClass (x: Int, y: Int)
defined class MyClass
scala> val p = MyClass(1, 2);
p: MyClass = MyClass(1,2)
scala> p match {
     |   case MyClass(a, b) => println(s"a=$a, b=$b")  // 1
     | }
a=1, b=2
scala> p match {
     |   case a MyClass b => println(s"a=$a, b=$b")   // 2
     | }
a=1, b=2

At the part of the code labeled 2, the extractor pattern is written in infix form. Conversely, the :: extractor, which we usually write in infix form, can also be written in prefix form:

scala> val p = List(1, 2, 3, 4)
p: List[Int] = List(1, 2, 3, 4)
scala>  p match {
     |     case ::(head, tail) => println(s"head=$head, tail=$tail") // 1
     |     case _ => println("What's up???")
     |  }
head=1, tail=List(2, 3, 4)

At 1, the unapply method of :: is called, with the pattern written in prefix form. The familiar infix form is just another way of writing the same pattern:

case head :: tail => println(s"head=$head, tail=$tail")
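Nothing about :: is special here: any object with a suitable unapply method can be used in a pattern. The sketch below uses a hypothetical plain (non-case) class Point and writes by hand the apply and unapply methods that a case class would generate for us.

```scala
// A plain class, so no apply or unapply is generated for us.
class Point(val x: Int, val y: Int)

object Point {
  // Called when Point(...) is used as an expression.
  def apply(x: Int, y: Int): Point = new Point(x, y)
  // Called by the compiler when Point(...) appears in a pattern.
  def unapply(p: Point): Option[(Int, Int)] = Some((p.x, p.y))
}

object ExtractorDemo {
  def describe(p: Point): String = p match {
    case Point(a, b) if a == b => s"diagonal: $a" // extractor plus a guard
    case Point(a, b)           => s"a=$a, b=$b"
  }
}
```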

Stack overflows

Our recursive solution works; however, there is a problem waiting to strike us. The code works fine for small lists with a few elements. Let's stress test it with a big list that has 20,000 elements:

  • Call the count method as shown in the following code:
    val l = 1 to 20000 // A range object
    count(l.toList) // Converts the range into a list
  • Run the code. Depending on the JVM's thread stack size, you will most likely get a java.lang.StackOverflowError. The problem here is the recursive call 1 + count(tail).

Each intermediate context is remembered on a stack frame. The intermediate context here is "get me a count of the tail, and add one to it". How many such intermediate contexts are there? You have probably guessed it already: one for each element of the list.

In other words, the number of contexts to remember is proportional to n, the length of the list; in algorithmic terms, it is O(n). So, for this example, we need 20,000 stack frames for a list of 20,000 elements, and the JVM usually cannot allocate that many with its default stack size. Hence, the routine looks broken. Now, what good is the technique if it does not work for large lists? We are in a logjam, as you can see in the following figure:


Figure 3.3: An example of stack overflow

Tail recursion to the rescue

There is a technique, an optimization, that really helps us get out of the logjam. However, we need to tweak the code a bit for it. We will make the recursive call the last thing the method does, so that there is no intermediate context to remember. Such a last call is called a tail call, and code in this form is amenable to TCO. Behind the scenes, Scala turns it into a loop, so the generated code does not need a new stack frame for each recursive call:

import scala.annotation.tailrec
def count(list: List[Int]): Int = {
  @tailrec   // 1
  def countIt(l: List[Int], acc: Int): Int = l match {
    case Nil => acc // 2
    case head :: tail => countIt(tail, acc + 1) // 3
  }
  countIt(list, 0)
}

The changes are as follows. We have a nested workhorse method that does all the hard work. The count method calls the nested recursive method countIt, passing it the list it received as an argument and an accumulator. The earlier intermediate context is now expressed through the accumulator, as the numbered comments show:

  1. The @tailrec annotation makes sure that we are doing the right things to benefit from TCO. Without @tailrec, Scala still optimizes the calls that qualify but silently compiles the rest as ordinary recursive calls; with it, the compiler reports an error whenever it cannot optimize the call. Try the same annotation on our first version: you will get a compilation error.
     Play time: change the second case in the previous code as shown:
    case head :: tail => {
      val cnt: Int = countIt(tail, acc + 1)
      println(cnt)
      cnt
    }
     You will get a compilation error after this change:
    Error: could not optimize @tailrec annotated method countIt: it contains a recursive call not in tail position
  2. When the execution lands in this clause, we are at the end of the list. We are not going to find any more elements, so the base case just returns the accumulator.
  3. There is no intermediate context now, just a tail call: the recursive call is the last and only thing this clause does. We found one more element, so we increment the accumulator to record the fact and pass it on.

Compared to the earlier version, why does this version work? If the compiler can use tail call optimization, it does not need to stack up the context. No additional stack frames are needed, as the resulting executable code uses a loop behind the scenes.
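To see why no stack frames pile up, it can help to write the loop by hand. The sketch below is only an approximation of what TCO amounts to (the compiler rewrites the tail call into a jump back to the start of the method, not into this exact source); CountLoop and countWithLoop are hypothetical names, and the local vars are confined to the method.

```scala
import scala.annotation.tailrec

object CountLoop {
  // The tail-recursive version from the text.
  def count(list: List[Int]): Int = {
    @tailrec
    def countIt(l: List[Int], acc: Int): Int = l match {
      case Nil       => acc
      case _ :: tail => countIt(tail, acc + 1)
    }
    countIt(list, 0)
  }

  // Roughly what the optimized tail call does: rebind the parameters and loop.
  def countWithLoop(list: List[Int]): Int = {
    var l = list
    var acc = 0
    while (l.nonEmpty) {
      l = l.tail // plays the role of the `tail` argument
      acc += 1   // plays the role of `acc + 1`
    }
    acc
  }
}
```

Both versions run in constant stack space, so both handle lists far larger than the one that broke the first count.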

Getting the nth element of a list

A list holds a certain number of elements. The first element is at index 0, the second at index 1, and so on. If we access an index that is out of range, for example List(1, 2, 3)(5), we get an IndexOutOfBoundsException.

We will write a method to find the nth element of a list. This method will return an Option: if n is out of bounds, we return None; otherwise, we return Some(elem). Let's look at the code and then a diagram to understand it better:

import scala.annotation.tailrec
object NthElemOfList extends App {
  def nth(list: List[Int], n: Int): Option[Int] = {
    @tailrec
    def nthElem(list: List[Int], acc: (Int, Int)): Option[Int] = list match {
      case Nil => None
      case head :: tail =>
        if (acc._1 == acc._2)                   // 1
          Some(head)
        else
          nthElem(tail, (acc._1 + 1, acc._2))   // 2
    }
    nthElem(list, (0, n))                       // 3
  }
  val bigList = (1 to 100000).toList            // 4
  println(nth(List(1, 2, 3, 4, 5, 6), 3).getOrElse("No such elem"))
  println(nth(List(1, 2, 3, 4, 5, 6), 300).getOrElse("No such elem"))
  println(nth(bigList, 2333).getOrElse("No such elem"))
}

Here is a diagrammatic representation of the flow:

Figure 3.4: Conceptual stack frames

The description of the preceding diagram is as follows:

  • Our accumulator is a pair. The first element is a running index and holds the index of the current node, starting from 0. The second element is always fixed at n.
  • If index == n, we have found the element, and we return it wrapped in Some. Otherwise, we increment the running index and recursively call the method again on the tail of the list.
  • We wish to hide the accumulator from the client code, as keeping an accumulator is an implementation detail.
  • We test the method with a very big list. As @tailrec is in effect, TCO kicks in and we don't get any stack overflows. Try to rewrite the preceding code without using a pair for the accumulator, then compare your version with this one and decide which is better.

The following are a few points that we need to ponder:

  • Could we simply use acc as Int? Do we really need the pair?
  • Try writing the code so that we decrement the accumulator instead of incrementing it.
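One possible answer to both questions, sketched under the assumption that a negative n should also yield None: keep a single Int that counts down to 0 instead of a pair. NthCountdown and loop are hypothetical names.

```scala
import scala.annotation.tailrec

object NthCountdown {
  def nth(list: List[Int], n: Int): Option[Int] = {
    @tailrec
    def loop(l: List[Int], remaining: Int): Option[Int] = l match {
      case Nil                         => None                      // ran off the end: n too large
      case head :: _ if remaining == 0 => Some(head)                // counted down to the target
      case _ :: tail                   => loop(tail, remaining - 1) // still a tail call
    }
    if (n < 0) None else loop(list, n)
  }
}
```

The pair version compares a running index against a fixed n; here both are folded into one number, so there is less to carry around and nothing that can get out of sync.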

An expression parser

Let's look at an instructive example of how recursion and immutability go hand in hand: an infix expression parser.

In infix notation, the operator comes between its two operands; for example, 3+4+5-6 is written in infix notation.

We will look at the iterative Java version first and then at the recursive Scala version. We support only the + and * operators, plus bracketed sub-expressions.

Evaluating the expression (1+2)*3*(2+4) should give us 54, and evaluating (1+2)*3+4 should give us 13. The grammar for our expression parser is shown in the following code. Note how an expression is composed of terms and factors, and a factor may itself contain a bracketed expression. In short, the grammar is recursively defined. Here is the grammar:

Expr: Term | Term + Expr
Term: Factor | Factor * Term
Factor: [0-9]+ | '(' Expr ')'

Here is a diagrammatic representation of the expression tree:

Figure 3.5: The expression tree

Look at the bracketed expression node in the bottom-right corner of the preceding image. The grammar also tells us that multiplication (*) takes precedence over addition (+), and that a bracketed expression has the highest precedence. Here is the Java code, which makes use of a tokenizer:

import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.Validate;

public class Parser {
  private static final String L_BRACKET = "(";
  private static final String R_BRACKET = ")";

  private class Tokenizer {
    private static final String NUM_PATTERN = "(\\d+).*";
    private String s;
    private final Pattern p;

    public Tokenizer(final String s) {
      this.s = s;
      this.p = Pattern.compile(NUM_PATTERN);
    }

    public boolean nextTokenIs(final String tok) {
      if (s.isEmpty()) {
        return false;
      }
      return s.startsWith(tok);
    }

    // Removes tok from the front of the remaining input.
    public void consume(final String tok) {
      s = s.replaceFirst(Pattern.quote(tok), "");
    }

    public boolean nextTokenIsNumber() {
      return s.matches(NUM_PATTERN);
    }

    public int consumeANumber() {
      if (s.matches(NUM_PATTERN)) {
        final Matcher m = p.matcher(s);
        Validate.isTrue(m.matches(), "Could not extract number from <" + s + ">");
        final String numStr = m.group(1);
        s = s.replaceFirst(Pattern.quote(numStr), "");
        return Integer.valueOf(numStr);
      }
      throw new IllegalArgumentException("Number expected");
    }
  }

  private final Tokenizer tokenizer;

  public Parser(final String s) {
    tokenizer = new Tokenizer(s);
  }

  private int factor() {
    if (tokenizer.nextTokenIsNumber()) {
      final int num = tokenizer.consumeANumber(); // 1
      return num;
    }
    if (tokenizer.nextTokenIs(L_BRACKET)) {
      tokenizer.consume(L_BRACKET);               // 2
      final int num = expr();
      if (!tokenizer.nextTokenIs(R_BRACKET)) {
        throw new IllegalArgumentException("Syntax error - ) missing");
      }
      tokenizer.consume(R_BRACKET);               // 3
      return num;
    }
    throw new IllegalArgumentException("Either number or ( expected");
  }

  private int term() {
    int val = factor();                           // 4
    while (tokenizer.nextTokenIs("*")) {          // 5
      tokenizer.consume("*");
      val *= factor();                            // 6
    }
    return val;
  }

  public int expr() {
    int val = term();                             // 7
    while (tokenizer.nextTokenIs("+")) {          // 8
      tokenizer.consume("+");
      val += term();                              // 9
    }
    return val;
  }

  public static void main(final String[] args) {
    System.out.println(new Parser("(1+2)*3*(2+4)").expr()); // prints 54
  }
}

We have a very simple Tokenizer class. It lets the parser consume an input string one token at a time; for example, given the expression (111+222), it yields (, 111, +, 222, and ). Although it is a pretty simple tokenizer, it is sufficient for our needs, and it uses Java's regular expression facilities to recognize numbers. The numbered comments in the parser work as follows:

  1. The code obeys the preceding grammar. A factor is either a number or a sub-expression starting with '('. If it is a number, we consume and return it.
  2. If it is a sub-expression instead, we consume '(' and evaluate the sub-expression first. So, given the expression (9+4)*4, we first evaluate 9+4.
  3. Once we have reduced the sub-expression to a number, we consume ')', which completes the bracketed sub-expression, and return the number to the caller.
  4. After bracketed sub-expressions, multiplication has the next highest precedence. Any term (a product) starts with a factor.
  5. Once we have consumed a factor, if the next token is a *, we consume it too.
  6. We keep multiplying by further factors. Once no more * tokens follow, we return the resulting number.
  7. This is the topmost method. It first reduces the input to a term, which could be as simple as a single number.
  8. While the next token is +, we keep looking for more terms to add.
  9. We keep adding the reduced values to get the value of the overall expression.

Try the input (1+2)): the parser returns 3 and silently ignores the extra ')'. How would you fix it?
Try rewriting the code using recursion. What can we say about TCO on the Java Virtual Machine?

Here is the code for the Scala version:

import scala.annotation.tailrec
object Parser extends App {
  val Number = """^(\d+).*""".r      // 1
  val LParen = """^[(].*""".r
  def factor(f: String): (String, Int) = f match {
    case Number(d) => (f.drop(d.length), d.toInt) // 2
    case LParen(_*) => {            
      val p = expr(f.drop(1), 0)  // 3
      val e = p._1
      if (e.take(1) == ")") {  // 4
        (e.drop(1), p._2)
      } else {
        throw new IllegalArgumentException("Right bracket missing")
      }
    }
    case _ => throw new IllegalArgumentException("Number or sub-expression expected")
  }
  @tailrec
  def term(t: String, acc: Int): (String, Int) = {
    val p = factor(t)
    val e = p._1
    if (e.take(1) == "*") {                 // 5
      term(e.drop(1), acc * p._2)       // 6
    } else {
      (e, acc * p._2)                           // 7
    }
  }
  @tailrec
  def expr(s: String, acc: Int): (String, Int) = {
    val p = term(s, 1)
    val e = p._1
    if (e.take(1) == "+") {                  // 8
      expr(e.drop(1), acc + p._2)         // 9
    } else {
      (e, acc + p._2)                             // 10
    }
  }
  def expr(s: String): Int = {
    val e = expr(s, 0)
    e._2
  }
  val p = expr("(1+2)*3*(2+4)")
  println(p)
}

Since the execution flow is a bit involved, the following diagram will help you understand it:

Figure 3.6: The flow for the simple expression (1+2)

At 1, we create a regular expression pattern (a Regex) to match a number. We use a triple-quoted string, in which the backslash is not an escape character, so to match digits we can simply write \d instead of \\d.

At 2, to match and extract either a number or '(', we use a Regex as an extractor in a pattern match. Here is the same technique in the REPL:

scala> val num = """(\d+)([.]\d+)?""".r
num: scala.util.matching.Regex = (\d+)([.]\d+)?
scala> "101.22" match {
     |   case num(decimal, fractional) => s"decimal = $decimal, fractional = $fractional"
     | }
res0: String = decimal = 101, fractional = .22

Let's dissect the code.

Salient points (the numbers match the labels in the code):

  3. This is a call to evaluate a sub-expression that is enclosed in brackets. For 2*(3+3), the 3+3 addition is evaluated before the multiplication.
  4. The closing ')' bracket is consumed. This completes the evaluation of a bracketed sub-expression.
  5. We are in the middle of a term.
  6. A term keeps evaluating itself and any subsequent factors. Refer to the grammar.
  7. Now that we are done evaluating a term, we return the remaining string and the result of the term evaluation.
  8. We are in the middle of an expression.
  9. We keep evaluating, giving precedence to term evaluation.
  10. Finally, we return the pair of the leftover string and the result of the overall expression.

Take a somewhat bigger expression and work through the code.
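The (1+2)) exercise applies to the Scala version too: the top-level expr(s: String) discards the leftover string, so an extra ')' goes unnoticed. The sketch below repeats the parser compactly, using tuple destructuring instead of p._1 and p._2, and adds a hypothetical parseAll entry point that rejects any leftover input. StrictParser and parseAll are our names, not the book's.

```scala
import scala.annotation.tailrec

object StrictParser {
  private val Number = """(\d+).*""".r
  private val LParen = """[(].*""".r

  private def factor(f: String): (String, Int) = f match {
    case Number(d) => (f.drop(d.length), d.toInt)
    case LParen() =>
      val (rest, v) = expr(f.drop(1), 0) // evaluate the bracketed sub-expression
      if (rest.take(1) == ")") (rest.drop(1), v)
      else throw new IllegalArgumentException("Right bracket missing")
    case _ => throw new IllegalArgumentException("Number or sub-expression expected")
  }

  @tailrec
  private def term(t: String, acc: Int): (String, Int) = {
    val (rest, v) = factor(t)
    if (rest.take(1) == "*") term(rest.drop(1), acc * v)
    else (rest, acc * v)
  }

  @tailrec
  private def expr(s: String, acc: Int): (String, Int) = {
    val (rest, v) = term(s, 1)
    if (rest.take(1) == "+") expr(rest.drop(1), acc + v)
    else (rest, acc + v)
  }

  // Hypothetical strict entry point: the whole input must be consumed.
  def parseAll(s: String): Int = {
    val (rest, v) = expr(s, 0)
    if (rest.isEmpty) v
    else throw new IllegalArgumentException(s"Unexpected input: <$rest>")
  }
}
```

With this entry point, parseAll("(1+2))") throws instead of returning 3. The same check (no input left after the top-level expression) would fix the Java version as well.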

When we mix strings and regular expressions, we often see the leaning toothpick syndrome. The regular expression notation \d matches a digit character. However, in an ordinary string literal, the backslash also works as an escape character, so you need to double it for the Regex engine to see it, for example:

scala> val regex = "H\\dllo".r
regex: scala.util.matching.Regex = H\dllo
scala> regex findFirstIn("H1llo H2llo")
res0: Option[String] = Some(H1llo)

Scala's triple-quoted strings allow us to express the regular expression in a natural way. Now, try using the following expression in the preceding code snippet:

val regex = """H\dllo""".r
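A quick check, as a sketch, that both spellings produce the same pattern. RegexDemo is a hypothetical name; Regex.regex returns the underlying pattern string, and findAllIn returns every match.

```scala
object RegexDemo {
  val escaped = "H\\dllo".r     // ordinary string: the backslash must be doubled
  val raw     = """H\dllo""".r  // triple-quoted string: the backslash is written once
}
```

Both values compile to the pattern H\dllo, so they match exactly the same strings.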

Persistent data structures

As you can guess by now, immutability is the underlying big theme. The following Java code reverses a list in place:

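import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;

List<Integer> list = Lists.newArrayList(1,2,3,4);
// List<Integer> refList = Lists.newArrayList(list); // 1
List<Integer> refList = list;   // refList aliases the very same list object
Collections.reverse(list);      // reverses in place
System.out.println(list);       // [4, 3, 2, 1]
System.out.println(refList);    // also [4, 3, 2, 1]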

The problem is that when we reverse the list in place, any code holding a reference to the same list also sees the change. To prevent this, we need the statement labeled 1: the defensive copy technique. The other problem with changing the list in place is thread safety: we need to carefully synchronize access to the list to stay away from heisenbugs. To know more about them, refer to the following URL:

http://opensourceforu.efytimes.com/2010/10/joy-of-programming-types-of-bugs/

Scala, instead, advocates immutable lists. We cannot change the list structure in place; instead, we create a new list:

import scala.annotation.tailrec
object ReverseAList extends App {
  @tailrec
  def reverseList(list: List[Int], acc: List[Int]): List[Int] = list match {
    case head :: tail => reverseList(tail, head :: acc) // prepend head to the accumulator
    case Nil => acc                                     // the accumulator holds the reversed list
  }
  val l = (1 to 20000).toList
  println(reverseList(l, Nil))
}

We again use the accumulator idiom to make the method tail recursive. As Scala's List is immutable, we need to create a new list each time a node is added. You may ask, won't it be expensive to create a brand new list each time? Not really: as the lists are immutable, they can be structurally shared, as shown in the following figure.

Figure 3.7: An example of lists

As the preceding diagram shows, when the node with value 3 is added in front of list 2, the resulting list 3 shares the node with value 2 and the node with value 1 with list 2, and list 2 can still be traversed exactly as before. As both these lists are immutable, the nodes can be safely shared.

Such a data structure, which always preserves the previous version of itself when it is modified, is called a persistent data structure. And no, it has nothing to do with persistence in the disk or database sense.
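Structural sharing can be observed directly with reference equality (eq). In this minimal sketch, list3 is built by prepending 3 to list2, and its tail is the very same object as list2; SharingDemo is a hypothetical name.

```scala
object SharingDemo {
  val list2 = 2 :: 1 :: Nil // nodes: 2 -> 1
  val list3 = 3 :: list2    // one new node (3) pointing at list2's existing nodes
}
```

Prepending allocates a single node and copies nothing, which is why the accumulator in reverseList is cheap to grow.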

Let's look at the following list concatenation:

val l1 = List(1,2,3)
val l2 = List(4,5,6)
val l3 = l1 ++ l2  // List(1,2,3,4,5,6)

Here, we need to copy l1 nodes and structurally share l2 nodes.

We cannot just change l1, as it is immutable. Instead, we copy the l1 nodes into l3 and make the copy of the third node point at l2, so that anyone already referring to l1 is not affected.
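The same idea in code, as a sketch. It uses :::, List's own concatenation operator; in the standard library's implementation, l1 ::: l2 copies the nodes of l1 and reuses l2 as the tail, so the sketch relies on ::: rather than on ++. ConcatDemo is a hypothetical name.

```scala
object ConcatDemo {
  val l1 = List(1, 2, 3)
  val l2 = List(4, 5, 6)
  val l3 = l1 ::: l2 // List(1, 2, 3, 4, 5, 6): l1's nodes copied, l2's nodes shared
}
```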

Now, let's try our hand at an example. In the following diagram, link up the dangling left-hand pointer of the node with value 4 so that the tree is structurally shared. Then draw the shared structure after inserting 22 in the tree:

Figure 3.8: Persistent tree after inserting node 100
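For the tree exercise, here is a sketch of path copying on an unbalanced binary search tree. Tree, Leaf, Node, and PersistentTree are hypothetical names, not the book's code. Insertion copies only the nodes on the path from the root to the new leaf; every other subtree is shared, unchanged, with the old version.

```scala
sealed trait Tree
case object Leaf extends Tree
final case class Node(left: Tree, value: Int, right: Tree) extends Tree

object PersistentTree {
  // Returns a new tree; the original tree is never modified.
  def insert(t: Tree, v: Int): Tree = t match {
    case Leaf => Node(Leaf, v, Leaf)                     // base case: the new node
    case n @ Node(l, x, r) =>
      if (v < x) n.copy(left = insert(l, v))             // copy this node, share its right subtree
      else if (v > x) n.copy(right = insert(r, v))       // copy this node, share its left subtree
      else n                                             // already present: share everything
  }

  // In-order traversal, handy for checking the contents.
  def toList(t: Tree): List[Int] = t match {
    case Leaf          => Nil
    case Node(l, x, r) => toList(l) ::: (x :: toList(r))
  }
}
```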

Two forms of recursion

In the previous sections, we saw the tail recursive code to reverse a list. Take a look at this form:

object ReverseAList1 extends App {
 def reverseList(list: List[Int]): List[Int] = list match {
 case head :: tail => reverseList(tail) :+ head
 case Nil => Nil
 }
 val l = (1 to 20000).toList
 println(reverseList(l))
}

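I know, this form is not tail recursive, so it will not benefit from tail call optimization. Try putting the @tailrec annotation on the reverseList method: you will get a compilation error. Also, as :+ copies the whole list to append a single element, this version takes time proportional to the square of the list's length.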

This form is still useful, though. Sometimes we do not want all the list elements; we just want to look at the first few elements of the result. In such cases, we want the evaluation of the recursive call to be deferred: instead of being computed upfront, it is only evaluated when needed. This is called delayed (lazy) evaluation. Scala's Stream (superseded by LazyList in Scala 2.13) implements lazy lists, whose elements are only evaluated when they are needed.
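A sketch of the deferred form, assuming Scala 2.13 or later, where LazyList replaces the deprecated Stream. lazyMap is a hypothetical helper with the same non-tail-recursive shape (use head, then recur on tail), but the recursive call sits to the right of #::, which takes it by name, so it only runs when that part of the result is demanded.

```scala
object LazyForm {
  // Not tail recursive, yet safe: each recursive call is deferred until needed.
  def lazyMap[A, B](list: List[A])(f: A => B): LazyList[B] = list match {
    case Nil          => LazyList.empty
    case head :: tail => f(head) #:: lazyMap(tail)(f)
  }
}
```

Taking the first few elements of the result forces only a few recursive calls, however long the input list is.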

Summary

In this article, you learned about recursion and problems that are recursive by nature; for example, directories on a Linux filesystem are structured this way. You also learned about recursive solutions, the general case, and the essential base case, which is needed so that the process eventually terminates. We looked at Scala's slice-and-dice technique and how to split a list so that we can visit each element. We saw how recursive calls and their associated context are remembered on stack frames, and why we need tail recursion. We saw examples where tail recursion enables TCO, and we worked through a detailed example of a small expression parser. We looked at how recursive code promotes immutability, as we did not use var in our code. Finally, we looked at persistent data structures and the two forms of recursion. Let's now move on to the wonders of Scala's power features for delayed evaluation.

You've been reading an excerpt of:

Scala Functional Programming Patterns
