Thursday 4 February 2010

Scala: Getting into performance trouble, calling head and tail on an ArrayBuffer

Update: The performance problem described below will be remedied in the final release of Scala 2.8. See martin's comment.

====================================

Recently, I wrote the following two versions of the same thing (computing frequencies):

// Version 1  --- Don't do this, lousy performance
// Scala 2.8
def freq[T](seq: Seq[T]): Map[T, Int] = {
  import annotation._
  @tailrec
  def freq(seq: Seq[T], map: Map[T, Int]): Map[T, Int] = {
    seq match {
      case s if s.isEmpty => map
      case s => {
        val elem = s.head
        val n = map.getOrElse(elem, 0) + 1
        freq(s.tail, map + (elem -> n))
      }
    }
  }
  freq(seq, Map())
}

// Version 2 --- 260 times faster than Version 1 on some input
def freq[T](seq: Seq[T]): Map[T, Int] = {
  val freqs = collection.mutable.HashMap[T, Int]()
  for (elem <- seq) {
    val n = freqs.getOrElseUpdate(elem, 0)
    freqs.update(elem, n + 1)
  }
  // Return immutable copy of freqs
  Map() ++ freqs
}


When comparing the two versions, it turned out that for some input, Version 1 was about 260 times slower (after JVM warm-up). The performance difference surfaced when both versions were called with the following different inputs:

val linesList = io.Source.fromPath("testfile.txt").getLines().toList
val linesSeq = io.Source.fromPath("testfile.txt").getLines().toSeq

Version 1 performs horribly when called with linesSeq as input, compared to when it is called with linesList. On my own, I couldn't figure out why, but helpful and knowledgeable people at #scala solved my problem in a few seconds. The explanation appears to be that 1) the default implementation of Seq is an ArrayBuffer, and 2) calling head and tail on an ArrayBuffer is costly. In particular, each call to tail copies the remaining elements into a new buffer, so consuming a whole sequence this way takes time roughly quadratic in its length. The same operations are cheap on a List. That's why Version 1 above is a performance trap.
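
The difference between the two tail implementations can be made visible without any timing. The following is a small sketch, not taken from the original discussion; the object name TailCostSketch and its fields are made up for illustration. It relies on List.tail returning the already existing rest of the list, while ArrayBuffer.tail builds a new buffer on every call.

```scala
import scala.collection.mutable.ArrayBuffer

object TailCostSketch {
  val list: List[Int] = List(1, 2, 3, 4)
  val buffer: ArrayBuffer[Int] = ArrayBuffer(1, 2, 3, 4)

  // List.tail hands back the existing rest-of-list cell: nothing is copied,
  // so two calls return the very same object.
  val listTailIsShared: Boolean = list.tail eq list.tail

  // ArrayBuffer.tail builds a fresh buffer holding the remaining elements,
  // so two calls give equal contents but distinct objects.
  val bufferTailIsCopied: Boolean =
    (buffer.tail == buffer.tail) && !(buffer.tail eq buffer.tail)
}
```

Both flags should come out true. Copying n - 1 elements, then n - 2, and so on, is what turns a single pass over the buffer into quadratic work.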

One way of getting better performance is to change the inner, two-argument freq method to use List instead of Seq:

// Version 1.b  --- Somewhat better
// Scala 2.8
def freq[T](seq: Seq[T]): Map[T, Int] = {
  import annotation._
  @tailrec
  def freq(seq: List[T], map: Map[T, Int]): Map[T, Int] = {
    seq match {
      case s if s.isEmpty => map
      case s => {
        val elem = s.head
        val n = map.getOrElse(elem, 0) + 1
        freq(s.tail, map + (elem -> n))
      }
    }
  }
  freq(seq.toList, Map())
}
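
Another option that keeps the immutable map but never calls head or tail is to let foldLeft do the traversal. This is only a sketch of an alternative, not one of the versions measured above, and it has not been benchmarked here; the object name FoldFreqSketch is made up for illustration.

```scala
object FoldFreqSketch {
  // Same immutable-map accumulation as the recursive versions, but the
  // traversal is left to foldLeft, so head and tail are never called.
  def freq[T](seq: Seq[T]): Map[T, Int] =
    seq.foldLeft(Map.empty[T, Int]) { (counts, elem) =>
      counts + (elem -> (counts.getOrElse(elem, 0) + 1))
    }
}
```

Since foldLeft walks the sequence with whatever traversal the collection provides, there is no need to convert the input to a List first.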


Better yet --- in Scala 2.8 --- is to scrap the entire method, and call groupBy(identity).mapValues(_.length) directly on the Seq...

Monday 1 February 2010

Counting Strings and Things in Scala (2.8)

I often need to count the frequencies of strings ("words", typically). Below are a few Scala snippets for counting strings and things. (Don't miss the last one.)

First try

Let's start with a method for counting string frequencies in a list:

// Scala 2.8
def freq(wds: List[String]): Map[String, Int] = {
  import annotation._
  @tailrec
  def freq(wds: List[String], map: Map[String, Int]): Map[String, Int] = {
    wds match {
      case l if l.isEmpty => map
      case l => {
        val elem = l.head
        val n = map.getOrElse(elem, 0) + 1
        freq(l.tail, map + (elem -> n))
      }
    }
  }
  freq(wds, Map())
}

It takes a list of strings, and returns a map (hash table) with a frequency count for each unique string. The one-argument freq method contains an embedded two-argument freq method. The inner method recursively consumes the elements of the list, incrementing the count for each element in its second argument, the accumulator map. The two-argument method is started with an empty map (at the end of the one-argument method: freq(wds, Map())).

In each recursion, a new, immutable word frequency map is produced, with the incremented frequency count. The
import annotation._
@tailrec
part asks the compiler to verify that the recursive call is in tail position, and to report an error if it is not. (The Scala compiler can optimize a special case of tail recursion: a method that calls itself as its last action is compiled to a loop.)
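
To make the tail position requirement concrete, here is a small sketch with two ways of summing a list. The names TailrecSketch, sum and sumNaive are made up for illustration and are not part of the frequency code.

```scala
import scala.annotation.tailrec

object TailrecSketch {
  // The recursive call is the last thing the method does, so the compiler
  // can turn it into a loop and the @tailrec annotation is accepted.
  @tailrec
  def sum(xs: List[Int], acc: Int): Int =
    if (xs.isEmpty) acc else sum(xs.tail, acc + xs.head)

  // Not in tail position: the addition happens after the recursive call
  // returns. Putting @tailrec on this method would be a compile error.
  def sumNaive(xs: List[Int]): Int =
    if (xs.isEmpty) 0 else xs.head + sumNaive(xs.tail)
}
```

The accumulator argument plays the same role as the map argument in the two-argument freq: it carries the result so far, so nothing is left to do after the recursive call.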

If the recursion makes you dizzy, you can use a mutable HashMap instead:
def freq(wds: List[String]): Map[String, Int] = {
  val freqs = collection.mutable.HashMap[String, Int]()
  for (w <- wds) {
    val n = freqs.getOrElseUpdate(w, 0)
    freqs.update(w, n + 1)
  }
  // Return immutable copy of freqs
  Map() ++ freqs
}

Second try

You'll soon find out that the above code is limited, since it only accepts List input. There is a more general concept, Seq, that will make it possible to call freq with different kinds of sequences (lists, listbuffers, arrays):
// Scala 2.8
def freq(wds: Seq[String]): Map[String, Int] = {
  import annotation._
  @tailrec
  def freq(wds: Seq[String], map: Map[String, Int]): Map[String, Int] = {
    wds match {
      case l if l.isEmpty => map
      case l => {
        val elem = l.head
        val n = map.getOrElse(elem, 0) + 1
        freq(l.tail, map + (elem -> n))
      }
    }
  }
  freq(wds, Map())
}

Third try

One day you find yourself relocated from the word counting department to the character counting department. A string is a sequence, but of Chars, not Strings. The code above will not help you count character frequencies. Here is an attempt at generalising the code further, to make it able to count the frequencies of any thing, T, not just String:
// Scala 2.8
def freq[T](seq: Seq[T]): Map[T, Int] = {
  import annotation._
  @tailrec
  def freq(seq: Seq[T], map: Map[T, Int]): Map[T, Int] = {
    seq match {
      case s if s.isEmpty => map
      case s => {
        val elem = s.head
        val n = map.getOrElse(elem, 0) + 1
        freq(s.tail, map + (elem -> n))
      }
    }
  }
  freq(seq, Map())
}



Here's the more general non-tail-recursive version:
def freq[T](seq: Seq[T]): Map[T, Int] = {
  val freqs = collection.mutable.HashMap[T, Int]()
  for (elem <- seq) {
    val n = freqs.getOrElseUpdate(elem, 0)
    freqs.update(elem, n + 1)
  }
  // Return immutable copy of freqs
  Map() ++ freqs
}
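
A short usage sketch shows what the type parameter buys: the same method now counts characters, words and numbers. The object name GenericFreqUsage and the sample inputs are made up for illustration, and the method body is a compact restatement of the mutable version above (using toMap for the final copy) so that the sketch stands alone.

```scala
object GenericFreqUsage {
  // Compact restatement of the mutable-HashMap version, repeated here so
  // the sketch is self-contained.
  def freq[T](seq: Seq[T]): Map[T, Int] = {
    val freqs = scala.collection.mutable.HashMap[T, Int]()
    for (elem <- seq) freqs.update(elem, freqs.getOrElse(elem, 0) + 1)
    freqs.toMap // immutable copy
  }

  // A string viewed as a sequence of Chars.
  val charCounts: Map[Char, Int] = freq("mississippi".toSeq)
  // A list of words.
  val wordCounts: Map[String, Int] = freq(List("to", "be", "or", "not", "to", "be"))
  // An array of integers, converted to a Seq.
  val intCounts: Map[Int, Int] = freq(Array(1, 2, 2, 3).toSeq)
}
```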

Hooray.

Last try (shamelessly lifted from someone at #scala)

But... you can still do better than this. A while ago, someone on the #scala IRC channel (unfortunately, I don't remember this person's name) answered a question on how to associate each integer in a sequence with the number of times it occurred (or something like that). It turns out that, in Scala 2.8, it is possible to write a frequency counting thing even more compactly:
def freq[T](seq: Seq[T]) = seq.groupBy(x => x).mapValues(_.length)
It's so short that it's almost not worth defining a method/function for it. You can simply call .groupBy(x => x).mapValues(_.length) directly on your Seq. (Or groupBy(identity).mapValues(_.length), which is the same thing.)
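
To see why the one-liner works, it helps to look at the intermediate result. The sketch below splits it into its two steps; the object name GroupBySketch and the sample data are made up for illustration. The second step uses a plain map over the entries rather than mapValues, on the assumption that a strict Map is wanted, since in some Scala versions mapValues returns a lazy view instead.

```scala
object GroupBySketch {
  val words: Seq[String] = Seq("a", "b", "a", "c", "a")

  // Step 1: group equal elements together, keyed by the element itself.
  val grouped: Map[String, Seq[String]] = words.groupBy(identity)

  // Step 2: replace each group by its size.
  val counts: Map[String, Int] =
    grouped.map { case (word, group) => (word, group.length) }
}
```

Here grouped maps "a" to a sequence of three "a"s, and "b" and "c" to one-element sequences, so counts ends up as "a" -> 3, "b" -> 1, "c" -> 1.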

Double hooray.

Benchmarking is tricky, but a small test indicates that the last, most beautiful, version is also the quickest, and that the recursive ones using only immutable maps (in some situations) are quite slow.
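
For a rough comparison of this kind, a helper along the following lines can be used. It is a naive sketch and not the code behind the numbers mentioned in these posts; the names TimingSketch and timed, and the warmUps parameter, are made up for illustration.

```scala
object TimingSketch {
  // Runs `body` a few times to warm up the JVM, then times one more run.
  // Returns the result together with the elapsed time in milliseconds.
  def timed[A](warmUps: Int)(body: => A): (A, Double) = {
    var i = 0
    while (i < warmUps) { body; i += 1 }
    val start = System.nanoTime()
    val result = body
    val elapsedMs = (System.nanoTime() - start) / 1e6
    (result, elapsedMs)
  }
}
```

A call such as TimingSketch.timed(10)(freq(linesSeq)) returns the frequency map and a single timing. Numbers obtained this way are only indicative, since garbage collection and JIT compilation can still distort one run; a dedicated harness such as JMH gives more trustworthy results.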