How do I apply the enrich-my-library pattern to Scala collections?

The key to understanding this problem is to realize that there are two different ways to build and work with collections in the collections library. One is the public collections interface with all its nice methods. The other, which is used extensively inside the collections library but almost never outside of it, is builders.

Our problem in enriching is exactly the same one that the collections library itself faces when trying to return collections of the same type. That is, we want to build collections, but when working generically, we don’t have a way to refer to “the same type that the collection already is”. So we need builders.
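To make the “same type” problem concrete, here is a minimal sketch, assuming Scala 2.12 or later. `SameType`, `keepLoose` and `keepInto` are made-up names, not library API. A method written only against Iterable[A] can promise nothing better than Iterable[A], whereas a method that is handed a Builder for the concrete type can return exactly that type.

```scala
import scala.collection.mutable.Builder

object SameType {
  // Written only against Iterable[A]: the static result type is Iterable[A],
  // even when the argument is a Vector at run time.
  def keepLoose[A](xs: Iterable[A])(p: A => Boolean): Iterable[A] =
    xs.filter(p)

  // The caller hands in a builder for the concrete result type To,
  // so the static result type is whatever that builder produces.
  def keepInto[A, To](xs: Iterable[A])(p: A => Boolean)(b: Builder[A, To]): To = {
    xs.foreach(a => if (p(a)) b += a)
    b.result()
  }
}
```

With this, keepInto(Vector(1, 2, 3, 4))(_ % 2 == 0)(Vector.newBuilder[Int]) is typed as Vector[Int], while keepLoose on the same input is only known to be an Iterable[Int].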

Now the question is: where do we get our builders from? The obvious place is the collection itself. This doesn’t work. We already decided, in moving to a generic collection, that we were going to forget the type of the collection. So even though the collection could return a builder that would generate more collections of the type we want, the compiler wouldn’t know statically what that type was.

Instead, we get our builders from CanBuildFrom implicits that are floating around. These exist specifically for the purpose of matching input and output types and giving you an appropriately typed builder.
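To see the idea in isolation, here is a toy imitation. `BuildFromLike` and `mapSame` are hypothetical names; this is not the real CanBuildFrom (which Scala 2.13 removed), just a sketch, assuming Scala 2.12 or later, of the same trick: one implicit builder factory per collection type, selected by the input type, which in turn fixes the output type.

```scala
import scala.collection.mutable.Builder

// Hypothetical stand-in for CanBuildFrom[-From, -Elem, +To]: a factory for builders.
trait BuildFromLike[-From, -Elem, +To] {
  def apply(): Builder[Elem, To]
}

object BuildFromLike {
  // Instances live in the companion object, so implicit search finds them automatically.
  implicit def listInstance[A, B]: BuildFromLike[List[A], B, List[B]] =
    new BuildFromLike[List[A], B, List[B]] {
      def apply(): Builder[B, List[B]] = List.newBuilder[B]
    }

  implicit def vectorInstance[A, B]: BuildFromLike[Vector[A], B, Vector[B]] =
    new BuildFromLike[Vector[A], B, Vector[B]] {
      def apply(): Builder[B, Vector[B]] = Vector.newBuilder[B]
    }
}

object MapSame {
  // The result type To is never written here; it comes from the implicit instance
  // chosen for the input type CC[A].
  def mapSame[A, B, CC[X] <: Iterable[X], To](xs: CC[A])(f: A => B)(
      implicit bf: BuildFromLike[CC[A], B, To]): To = {
    val b = bf()
    xs.foreach(a => b += f(a))
    b.result()
  }
}
```

With these instances, mapSame(List(1, 2, 3))(_ * 2) is statically a List[Int] and mapSame(Vector(1, 2))(_.toString) a Vector[String], even though mapSame itself never names either type.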

So, we have two conceptual leaps to make:

  1. We aren’t using standard collection operations, we’re using builders.
  2. We get these builders from implicit CanBuildFroms, not from our collection directly.

Let’s look at an example.

import scala.collection.generic.CanBuildFrom  // Scala 2.8-2.12; CanBuildFrom was removed in 2.13

class GroupingCollection[A, C[A] <: Iterable[A]](ca: C[A]) {
  // Splits ca into runs, starting a new run whenever p(previous, current) is false.
  def groupedWhile(p: (A,A) => Boolean)(
    implicit cbfcc: CanBuildFrom[C[A],C[A],C[C[A]]],  // builder for the outer C[C[A]]
             cbfc: CanBuildFrom[C[A],A,C[A]]          // builder for each inner C[A]
  ): C[C[A]] = {
    val it = ca.iterator
    val cca = cbfcc()
    if (it.hasNext) {
      val as = cbfc()
      var olda = it.next()
      as += olda
      while (it.hasNext) {
        val a = it.next()
        if (p(olda,a)) as += a
        else { cca += as.result(); as.clear(); as += a }  // close this group, start the next
        olda = a
      }
      cca += as.result()  // close the final group
    }
    cca.result()  // an empty input gives an empty C[C[A]]
  }
}
implicit def iterable_has_grouping[A, C[A] <: Iterable[A]](ca: C[A]) = {
  new GroupingCollection[A,C](ca)
}

Let’s take this apart. First, in order to build the collection-of-collections, we know we’ll need to build two types of collections: C[A] for each group, and C[C[A]] that gathers all the groups together. Thus, we need two builders, one that takes As and builds C[A]s, and one that takes C[A]s and builds C[C[A]]s. Looking at the type signature of CanBuildFrom, we see

CanBuildFrom[-From, -Elem, +To]

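which means that CanBuildFrom wants to know the type of collection we’re starting with (in our case, C[A]), then the element type of the generated collection, and then the type of the generated collection itself. So we fill those in as the implicit parameters: cbfc is a CanBuildFrom[C[A],A,C[A]] (As in, a C[A] out), and cbfcc is a CanBuildFrom[C[A],C[A],C[C[A]]] (C[A]s in, a C[C[A]] out).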

Having realized this, that’s most of the work. We can use our CanBuildFroms to give us builders (all you need to do is apply them). A builder accumulates elements with +=, converts them into the collection it is ultimately supposed to produce with result, and empties itself, ready to start again, with clear. The builders start off empty, which solves the first compile error in the question, and since we’re using builders instead of recursion, the second error also goes away.
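Here is the same loop rewritten as a sketch that does not depend on CanBuildFrom, assuming Scala 2.12 or later (CanBuildFrom no longer exists in 2.13). `BuilderGrouping` and `groupedWhileWith` are hypothetical names, and the two builder factories are ordinary function arguments standing in for cbfc and cbfcc; the +=, result and clear calls are the same ones described above.

```scala
import scala.collection.mutable.Builder

object BuilderGrouping {
  // newInner plays the role of cbfc (As in, CA out);
  // newOuter plays the role of cbfcc (CAs in, CCA out).
  def groupedWhileWith[A, CA, CCA](xs: Iterable[A])(p: (A, A) => Boolean)(
      newInner: () => Builder[A, CA],
      newOuter: () => Builder[CA, CCA]): CCA = {
    val outer = newOuter()
    val it = xs.iterator
    if (it.hasNext) {
      val inner = newInner()      // builders start off empty
      var prev = it.next()
      inner += prev
      while (it.hasNext) {
        val a = it.next()
        if (p(prev, a)) inner += a
        else {
          outer += inner.result() // finish the current group
          inner.clear()           // empty the builder so it can start the next group
          inner += a
        }
        prev = a
      }
      outer += inner.result()     // the last group
    }
    outer.result()
  }
}
```

For example, passing () => List.newBuilder[Int] and () => List.newBuilder[List[Int]] groups List(1, 2, 2, 2, 3) with _ == _ into List(List(1), List(2, 2, 2), List(3)).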

One last little detail, other than the algorithm that actually does the work, is in the implicit conversion. Note that we use new GroupingCollection[A,C], not [A,C[A]]. This is because the class declaration takes C as a type constructor with one parameter, which it fills in itself with the A passed to it. So we just hand it the type C, and let it create C[A] out of it. Minor detail, but you’ll get compile-time errors if you try it another way.
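A small sketch of the same shape, with hypothetical names (`Wrapped`, `WrappedDemo`): the second type parameter is a type constructor, so callers pass List, and the class applies it to A itself.

```scala
// C is a type constructor, bounded so that C[X] is always an Iterable[X].
class Wrapped[A, C[X] <: Iterable[X]](ca: C[A]) {
  // Inside the class, C is applied to A; callers never write C[A] themselves.
  def underlying: C[A] = ca
  def firstOption: Option[A] = ca.headOption
}

object WrappedDemo {
  // Pass the constructor List, not the applied type List[Int].
  val ok: Wrapped[Int, List] = new Wrapped[Int, List](List(1, 2, 3))
  // new Wrapped[Int, List[Int]](List(1, 2, 3))  // rejected at compile time: List[Int] is not a type constructor
}
```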

Here, I’ve made the method a little more generic than just grouping equal elements: it cuts the original collection apart whenever its test on two sequential elements fails.

Let’s see our method in action:

scala> List(1,2,2,2,3,4,4,4,5,5,1,1,1,2).groupedWhile(_ == _)
res0: List[List[Int]] = List(List(1), List(2, 2, 2), List(3), List(4, 4, 4), 
                             List(5, 5), List(1, 1, 1), List(2))

scala> Vector(1,2,3,4,1,2,3,1,2,1).groupedWhile(_ < _)
res1: scala.collection.immutable.Vector[scala.collection.immutable.Vector[Int]] =
  Vector(Vector(1, 2, 3, 4), Vector(1, 2, 3), Vector(1, 2), Vector(1))

It works!

The only problem is that these methods are not, in general, available for arrays: an Array is not an Iterable, so reaching GroupingCollection would require two implicit conversions in a row (Array to WrappedArray, then to GroupingCollection), and the compiler applies at most one. There are several ways to get around this, including writing a separate implicit conversion for arrays, casting to WrappedArray, and so on.
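As a sketch of the first workaround, a separate conversion aimed directly at arrays, here is a hypothetical `ArrayGrouping` enrichment (the names are mine; it assumes Scala 2.10 or later for ClassTag). Because it starts from Array[A] itself, only one implicit conversion is needed. It builds arrays with Array.newBuilder, which needs a ClassTag; the tag for Array[A] is derived from A’s tag with wrap.

```scala
import scala.reflect.ClassTag

object ArrayGroupingSyntax {
  // Hypothetical enrichment that starts from Array[A], so no chain of conversions is needed.
  implicit class ArrayGrouping[A](xs: Array[A]) {
    def groupedWhile(p: (A, A) => Boolean)(implicit ct: ClassTag[A]): Array[Array[A]] = {
      val outer = Array.newBuilder[Array[A]](ct.wrap) // ClassTag[Array[A]] derived from ClassTag[A]
      val inner = Array.newBuilder[A]
      var i = 0
      while (i < xs.length) {
        // Close the current group whenever the test on neighbours fails.
        if (i > 0 && !p(xs(i - 1), xs(i))) {
          outer += inner.result()
          inner.clear()
        }
        inner += xs(i)
        i += 1
      }
      if (xs.nonEmpty) outer += inner.result() // the last group
      outer.result()
    }
  }
}
```

After import ArrayGroupingSyntax._, Array(1,2,5,3,5,6,7,4,1).groupedWhile(_ <= _) yields the groups (1, 2, 5), (3, 5, 6, 7), (4) and (1).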


Edit: My favored approach for dealing with arrays, strings and such is to make the code even more generic and then use appropriate implicit conversions to make it more specific again, in such a way that arrays work too. In this particular case:

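import scala.collection.generic.CanBuildFrom

class GroupingCollection[A, C, D[_]](ca: C)(
  implicit c2i: C => Iterable[A],        // views a C as an Iterable[A] (identity for most collections)
           cbf: CanBuildFrom[C,C,D[C]],  // builds the outer D[C] from Cs
           cbfi: CanBuildFrom[C,A,C]     // builds each inner C from As
) {
  def groupedWhile(p: (A,A) => Boolean): D[C] = {
    val it = c2i(ca).iterator
    val cca = cbf()
    if (it.hasNext) {
      val as = cbfi()
      var olda = it.next()
      as += olda
      while (it.hasNext) {
        val a = it.next()
        if (p(olda,a)) as += a
        else { cca += as.result(); as.clear(); as += a }  // close this group, start the next
        olda = a
      }
      cca += as.result()  // close the final group
    }
    cca.result()
  }
}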

Here we’ve added an implicit that gives us an Iterable[A] from C. For most collections this will just be the identity (e.g. List[A] already is an Iterable[A]), but for arrays it will be a real implicit conversion. Consequently, we’ve dropped the requirement that C[A] <: Iterable[A]; we’ve basically made the view-bound (<%) requirement explicit, so we can apply it ourselves when we choose instead of having the compiler fill it in for us. We have also relaxed the restriction that our collection-of-collections is C[C[A]]: it can now be any D[C], which we will fill in later with whatever we want. Because we’re going to fill this in later, we’ve pushed it up to the class level instead of the method level. Otherwise, it’s basically the same.
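The same “explicit view” idea can be shown on something smaller than groupedWhile. In this sketch (hypothetical names `ExplicitView` and `isSortedVia`; the view is passed explicitly rather than implicitly, to keep the example independent of version differences in implicit resolution), C has no bound at all. The caller supplies the function that turns a C into an Iterable[A], so arrays and strings are accepted just like lists.

```scala
object ExplicitView {
  // No bound on C: instead of requiring C <: Iterable[A], the caller hands over
  // a function that views a C as an Iterable[A].
  def isSortedVia[A, C](ca: C)(c2i: C => Iterable[A])(implicit ord: Ordering[A]): Boolean = {
    val xs = c2i(ca)
    // Check every pair of neighbouring elements.
    xs.zip(xs.drop(1)).forall { case (a, b) => ord.lteq(a, b) }
  }
}
```

For instance, isSortedVia(Array(1, 2, 2, 5))((a: Array[Int]) => a.toList) is true, and isSortedVia("cab")((s: String) => s.toList) is false.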

Now the question is how to use this. For regular collections, we can:

implicit def collections_have_grouping[A, C[A]](ca: C[A])(
  implicit c2i: C[A] => Iterable[A],
           cbf: CanBuildFrom[C[A],C[A],C[C[A]]],
           cbfi: CanBuildFrom[C[A],A,C[A]]
) = {
  new GroupingCollection[A,C[A],C](ca)(c2i, cbf, cbfi)
}

where we now plug in C[A] for C and C[C[A]] for D[C]. Note that we do need the explicit type arguments on the call to new GroupingCollection so it can keep straight which types correspond to what. Thanks to the implicit c2i: C[A] => Iterable[A], this automatically handles arrays.

But wait, what if we want to use strings? Now we’re in trouble, because you can’t have a “string of strings”. This is where the extra abstraction helps: we can make D something that is suitable for holding strings. Let’s pick Vector, and do the following:

val vector_string_builder = (
  new CanBuildFrom[String, String, Vector[String]] {
    def apply() = Vector.newBuilder[String]
    def apply(from: String) = this.apply()
  }
)

implicit def strings_have_grouping(s: String)(
  implicit c2i: String => Iterable[Char],
           cbfi: CanBuildFrom[String,Char,String]
) = {
  new GroupingCollection[Char,String,Vector](s)(
    c2i, vector_string_builder, cbfi
  )
}

We need a new CanBuildFrom to handle building a vector of strings (but this is really easy, since we just need to call Vector.newBuilder[String]), and then we need to fill in all the types so that the GroupingCollection is typed sensibly. Note that a CanBuildFrom[String,Char,String] is already in scope (Predef supplies it), so strings can be built from collections of chars.

Let’s try it out:

scala> List(true,false,true,true,true).groupedWhile(_ == _)
res1: List[List[Boolean]] = List(List(true), List(false), List(true, true, true))

scala> Array(1,2,5,3,5,6,7,4,1).groupedWhile(_ <= _) 
res2: Array[Array[Int]] = Array(Array(1, 2, 5), Array(3, 5, 6, 7), Array(4), Array(1))

scala> "Hello there!!".groupedWhile(_.isLetter == _.isLetter)
res3: Vector[String] = Vector(Hello,  , there, !!)
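For comparison, here is a hand-written sketch of the string case without CanBuildFrom (hypothetical `StringGrouping` object, assuming Scala 2.12 or later). A StringBuilder plays the Char-to-String role of cbfi, and Vector.newBuilder[String] plays the role of vector_string_builder.

```scala
object StringGrouping {
  // Groups a String into a Vector[String], cutting wherever p fails on neighbouring chars.
  def groupedWhile(s: String)(p: (Char, Char) => Boolean): Vector[String] = {
    val outer = Vector.newBuilder[String] // Strings in, Vector[String] out
    val inner = new StringBuilder         // Chars in, String out
    var i = 0
    while (i < s.length) {
      if (i > 0 && !p(s.charAt(i - 1), s.charAt(i))) {
        outer += inner.result()
        inner.clear()
      }
      inner += s.charAt(i)
      i += 1
    }
    if (s.nonEmpty) outer += inner.result() // the last group
    outer.result()
  }
}
```

StringGrouping.groupedWhile("Hello there!!")(_.isLetter == _.isLetter) produces the same four groups as the res3 line: "Hello", " ", "there" and "!!".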
