Inspired by the successful results of Haskell stream fusion (see Evolving Faster Haskell Programs (now with LLVM!) for some impressive optimizations), I wondered whether a similar concept is applicable to Scala collections. It turns out that with a combination of iterators and specialization it's possible to achieve similar optimizations in Scala.
The goal of stream fusion is essentially to optimize code like this:
def scalaLibrarySum(a : Array[Int]) = a.map(i => i * 3 + 7).filter(i => (i % 10) == 0).foldLeft(0)(_ + _)
into code like this:
def mapFilterSumLoop(a : Array[Int]) = {
  var i = 0
  var r = 0
  while (i < a.length) {
    val v = a(i) * 3 + 7
    if ((v % 10) == 0)
      r += v
    i += 1
  }
  r
}
If you run the scalaLibrarySum method, Scala will create two intermediate arrays holding the results of the map and filter operations. This is entirely unnecessary for this calculation: the functions passed to map and filter are side-effect free, so their application can be deferred until just before each value is needed in the fold operation. This is essentially how the mapFilterSumLoop method works.
Besides avoiding intermediate arrays, boxing of primitive values must be avoided if we want any chance of competitive performance (the Haskell libraries contain specialized instances for this purpose). Fortunately, Scala 2.8 supports specialization of type parameters, which lets us avoid boxing while still writing generic code. Unfortunately this feature seems to be quite buggy at the moment: just by playing around with a simple example I encountered two bugs (tickets #3148 and #3149). So, the code below contains some specialization done by hand. Hopefully these bugs will be fixed so that the code can be fully generalized.
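To illustrate what specialization buys us, here's a minimal standalone sketch. The Transform trait and SpecializationDemo object are hypothetical names for illustration only, not part of the benchmark code:

```scala
// Marking the type parameter @specialized makes the compiler generate
// primitive variants of the trait and its apply method, so calling it
// with an Int does not box the argument into a java.lang.Integer.
trait Transform[@specialized T] {
  def apply(x: T): T
}

object SpecializationDemo {
  // The anonymous class below implements the Int-specialized variant.
  val triplePlus7 = new Transform[Int] { def apply(x: Int) = x * 3 + 7 }

  def main(args: Array[String]): Unit =
    println(triplePlus7(11))  // prints 40
}
```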
The biggest difference compared to stream fusion in Haskell is that I use impure iterators in the Scala code. This is not as nice as the pure stream code used in Haskell, but it's a fact that Hotspot isn't nearly as good at optimizing pure functional code as GHC. Hotspot works best if fed imperative style loops.
Here are the definitions of the specialized functions and iterators I use in the benchmark below:
// Specialized Function1
trait Fn1[@specialized I, @specialized O] {
  def apply(a : I) : O
}

// Specialized Function2
trait Fn2[@specialized I1, @specialized I2, @specialized O] {
  def apply(a1 : I1, a2 : I2) : O
}

// Specialized iterator
trait SIterator[@specialized T] {
  def hasMore : Boolean
  def current : T
  def next()
}
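The protocol is: check hasMore, read current, then advance with next(). As a quick sanity check, here's a toy iterator over an integer range. RangeIterator and RangeDemo are hypothetical examples rather than part of the benchmark code, and I've written out the Unit return types so the snippet also compiles on recent Scala versions:

```scala
trait SIterator[@specialized T] {
  def hasMore: Boolean
  def current: T
  def next(): Unit
}

// Iterates over the ints in [from, until).
class RangeIterator(from: Int, until: Int) extends SIterator[Int] {
  private var i = from
  def hasMore = i < until
  def current = i
  def next(): Unit = i += 1
}

object RangeDemo {
  // Drain an iterator with the hasMore/current/next protocol.
  def sumRange(from: Int, until: Int): Int = {
    val it = new RangeIterator(from, until)
    var sum = 0
    while (it.hasMore) {
      sum += it.current
      it.next()
    }
    sum
  }

  def main(args: Array[String]): Unit =
    println(sumRange(1, 5))  // 1 + 2 + 3 + 4 = 10
}
```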
In addition to this I've defined array, filter and map iterators. Unfortunately these are not generic, due to the specialization bugs mentioned above:
class IntArrayIterator(a : Array[Int], var index : Int, endIndex : Int) extends SIterator[Int] {
  def next() = index += 1
  def current = a(index)
  def hasMore = index < endIndex
}
// Optimally this would be: class FilterIterator[@specialized T](iter : SIterator[T], pred : Fn1[T, Boolean]) extends SIterator[T]
class FilterIterator(iter : SIterator[Int], pred : Fn1[Int, Boolean]) extends SIterator[Int] {
  def hasMore = iter.hasMore
  def next() = {
    iter.next()
    findNext()
  }
  def findNext() = {
    while (iter.hasMore && !pred(iter.current))
      iter.next()
  }
  def current = iter.current
  findNext() // advance to the first matching element on construction
}
// Optimally this would be: class MapIterator[@specialized T, @specialized U](iter : SIterator[T], fn : Fn1[T, U]) extends SIterator[U]
class MapIterator(iter : SIterator[Int], fn : Fn1[Int, Int]) extends SIterator[Int] {
  def next() = iter.next()
  def current = fn(iter.current)
  def hasMore = iter.hasMore
}
The fold function is straightforward and generic:
def fold[@specialized T, @specialized U] (iter : SIterator[T], fn : Fn2[U, T, U], v : U) = {
  var r = v
  while (iter.hasMore) {
    r = fn(r, iter.current)
    iter.next()
  }
  r
}
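Here's a minimal standalone check of fold summing an array through an iterator. The definitions are repeated from above so the snippet compiles on its own, with explicit Unit return types added for recent compilers; the sum helper is mine, not part of the original code:

```scala
object FoldDemo {
  trait Fn2[@specialized I1, @specialized I2, @specialized O] {
    def apply(a1: I1, a2: I2): O
  }

  trait SIterator[@specialized T] {
    def hasMore: Boolean
    def current: T
    def next(): Unit
  }

  class IntArrayIterator(a: Array[Int], var index: Int, endIndex: Int) extends SIterator[Int] {
    def next(): Unit = index += 1
    def current = a(index)
    def hasMore = index < endIndex
  }

  def fold[@specialized T, @specialized U](iter: SIterator[T], fn: Fn2[U, T, U], v: U): U = {
    var r = v
    while (iter.hasMore) {
      r = fn(r, iter.current)
      iter.next()
    }
    r
  }

  // Sum an array by folding addition over its iterator.
  def sum(a: Array[Int]): Int =
    fold(new IntArrayIterator(a, 0, a.length),
         new Fn2[Int, Int, Int] { def apply(a1: Int, a2: Int) = a1 + a2 }, 0)

  def main(args: Array[String]): Unit =
    println(sum(Array(1, 2, 3, 4)))  // prints 10
}
```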
The map-filter-sum function can now be written using iterators:
def mapFilterSum(a : Array[Int]) = {
  val filter = new Fn1[Int, Boolean] {def apply(a : Int) = (a % 10) == 0}
  val map = new Fn1[Int, Int] {def apply(a : Int) = a * 3 + 7}
  val s = new FilterIterator(new MapIterator(new IntArrayIterator(a, 0, a.length), map), filter)
  fold(s, new Fn2[Int, Int, Int] {def apply(a1 : Int, a2 : Int) = a1 + a2}, 0)
}
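As a quick correctness check, the iterator pipeline can be compared against the hand-written loop on a small input. All definitions are repeated here so the snippet compiles on its own, with explicit Unit return types added for recent compilers:

```scala
object FusionCheck {
  trait Fn1[@specialized I, @specialized O] { def apply(a: I): O }
  trait Fn2[@specialized I1, @specialized I2, @specialized O] { def apply(a1: I1, a2: I2): O }

  trait SIterator[@specialized T] {
    def hasMore: Boolean
    def current: T
    def next(): Unit
  }

  class IntArrayIterator(a: Array[Int], var index: Int, endIndex: Int) extends SIterator[Int] {
    def next(): Unit = index += 1
    def current = a(index)
    def hasMore = index < endIndex
  }

  class MapIterator(iter: SIterator[Int], fn: Fn1[Int, Int]) extends SIterator[Int] {
    def next(): Unit = iter.next()
    def current = fn(iter.current)
    def hasMore = iter.hasMore
  }

  class FilterIterator(iter: SIterator[Int], pred: Fn1[Int, Boolean]) extends SIterator[Int] {
    def hasMore = iter.hasMore
    def next(): Unit = { iter.next(); findNext() }
    def findNext(): Unit = while (iter.hasMore && !pred(iter.current)) iter.next()
    def current = iter.current
    findNext() // advance to the first matching element on construction
  }

  def fold[@specialized T, @specialized U](iter: SIterator[T], fn: Fn2[U, T, U], v: U): U = {
    var r = v
    while (iter.hasMore) { r = fn(r, iter.current); iter.next() }
    r
  }

  def mapFilterSum(a: Array[Int]): Int = {
    val filter = new Fn1[Int, Boolean] { def apply(a: Int) = (a % 10) == 0 }
    val map = new Fn1[Int, Int] { def apply(a: Int) = a * 3 + 7 }
    val s = new FilterIterator(new MapIterator(new IntArrayIterator(a, 0, a.length), map), filter)
    fold(s, new Fn2[Int, Int, Int] { def apply(a1: Int, a2: Int) = a1 + a2 }, 0)
  }

  def mapFilterSumLoop(a: Array[Int]): Int = {
    var i = 0
    var r = 0
    while (i < a.length) {
      val v = a(i) * 3 + 7
      if ((v % 10) == 0) r += v
      i += 1
    }
    r
  }

  def main(args: Array[String]): Unit = {
    val a = (0 until 100).toArray
    println(mapFilterSum(a))                          // prints 1450
    println(mapFilterSum(a) == mapFilterSumLoop(a))   // prints true
  }
}
```

Both versions keep exactly the mapped values divisible by 10 (here 10, 40, ..., 280), so they must agree on any input.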
The full iterator code can be found here. Compile the code using the latest Scala 2.8 build with the -Yspecialize flag. The optimize flag doesn't seem to have much effect on the performance.
I've benchmarked six different implementations of the map-filter-sum calculation:
- The while loop shown above
- The while loop split up into map, filter and fold functions with intermediate array results passed between them
- The version using specialized iterators
- The Scala library implementation shown above
- Same as Scala library function but with a view instead
- Same as Scala library function but with a stream instead
The benchmark takes the minimum execution time of 200 runs of each function on an array of 1 million integers. Running the application with the latest OpenJDK 7 (Hotspot version "build 17.0-b10") and the flags "-server -XX:CompileThreshold=100" I get the following results:
Loop: (4990,-423600172)
Loop with intermediate arrays: (6690,-423600172)
Specialized iterators: (5367,-423600172)
Scala array: (46444,-423600172)
Scala view: (39625,-423600172)
Scala stream: (63210,-423600172)
The first value in each pair is the minimum execution time in microseconds; the second is the result of the calculation. As you can see, the method using specialized iterators is almost as fast as the single while loop. Hotspot has inlined all the iterator code, not bad! Using intermediate arrays is about 25% slower than specialized iterators. The Scala library versions are about 7-9 times slower, which is clearly a consequence of boxing. Of the three library variants, the view is fastest, as it also avoids creating intermediate arrays.
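For completeness, the min-of-N measurement scheme described above can be sketched like this. MinTiming and minTimeMicros are hypothetical names for illustration, not the actual benchmark code:

```scala
object MinTiming {
  // Runs `body` once to warm up, then `runs` more times, returning the
  // minimum elapsed time in microseconds together with the last result,
  // in the same (time, result) shape as the figures above.
  def minTimeMicros[A](runs: Int)(body: => A): (Long, A) = {
    var result: A = body  // warm-up run
    var best = Long.MaxValue
    var i = 0
    while (i < runs) {
      val t0 = System.nanoTime()
      result = body
      val t1 = System.nanoTime()
      best = math.min(best, (t1 - t0) / 1000)
      i += 1
    }
    (best, result)
  }

  def main(args: Array[String]): Unit = {
    val (micros, sum) = minTimeMicros(10)((1 to 1000).sum)
    println("(" + micros + "," + sum + ")")
  }
}
```

Taking the minimum rather than the mean filters out JIT compilation and GC noise, which is why the CompileThreshold flag matters less for the reported numbers.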
The conclusion from this simple experiment is that it's certainly possible to write collection libraries that combine a nice high-level interface with excellent performance. When Scala's specialization support is improved, hopefully this power will be available to all Scala programmers.
The full benchmark code can be found here.