Stratified sampling in Spark

One possible solution is in Holden’s answer, and here are some other solutions:

Using RDDs:

You can use the sampleByKeyExact transformation from the PairRDDFunctions class.

sampleByKeyExact(boolean withReplacement, scala.collection.Map fractions, long seed)
Return a subset of this RDD sampled by key (via stratified sampling) containing exactly math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
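In other words, the per-stratum sample size is deterministic: a stratum of n items sampled at fraction f keeps exactly math.ceil(n * f) of them. A quick check of that arithmetic for the two strata sizes used in the example below (10 and 2 items, both at a fraction of 0.8):

math.ceil(10 * 0.8).toInt // = 8 rows kept from the stratum of size 10
math.ceil(2 * 0.8).toInt  // = 2 rows kept from the stratum of size 2 (1.6 rounds up)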

And this is how I would do it:

Consider the following list:

val seq = Seq(
                (2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),
                (2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),
                (2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)
           )

I would create a pair RDD, mapping all the users as keys:

val data = sc.parallelize(seq).map(x => (x._1,(x._2,x._3)))

Then I’ll set up the fraction for each key as follows, since sampleByKeyExact takes a Map of fractions per key:

val fractions = data.map(_._1).distinct.map(x => (x,0.8)).collectAsMap

What I have done here is map over the keys to find the distinct ones, associate each with a fraction of 0.8, and collect the result as a Map.
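For the seq above, which has two distinct keys, the resulting map should look something like this (the ordering of a Map is not guaranteed):

fractions
// Map(2147481832 -> 0.8, 214748183 -> 0.8)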

Now, to sample:

import org.apache.spark.rdd.PairRDDFunctions // optional in Spark 1.3+, where the implicit conversion to PairRDDFunctions is in scope automatically
val sampleData = data.sampleByKeyExact(false, fractions, 2L)

or

val sampleData = data.sampleByKeyExact(withReplacement = false, fractions = fractions, seed = 2L)

You can check the counts on your keys, your data, or your data sample:

scala> data.count
// [...]
// res10: Long = 12

scala> sampleData.count
// [...]
// res11: Long = 10
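The drop from 12 to 10 matches the guarantee quoted above: key 2147481832 has 10 rows, so math.ceil(10 * 0.8) = 8 are kept, and key 214748183 has 2 rows, so math.ceil(2 * 0.8) = 2 are kept, giving 8 + 2 = 10.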

Using DataFrames:

Let’s consider the same data (seq) from the previous section.

import spark.implicits._ // needed for the toDF method on a local Seq
val df = seq.toDF("keyColumn","value1","value2")
df.show
// +----------+----------+------+
// | keyColumn|    value1|value2|
// +----------+----------+------+
// |2147481832|  23355149|     1|
// |2147481832| 973010692|     1|
// |2147481832|2134870842|     1|
// |2147481832| 541023347|     1|
// |2147481832|1682206630|     1|
// |2147481832|1138211459|     1|
// |2147481832| 852202566|     1|
// |2147481832| 201375938|     1|
// |2147481832| 486538879|     1|
// |2147481832| 919187908|     1|
// | 214748183| 919187908|     1|
// | 214748183|  91187908|     1|
// +----------+----------+------+

To do that, we will need the underlying RDD, on which we create tuples of the elements by defining our key to be the first column:

import scala.collection.Map // collectAsMap returns a scala.collection.Map, not the default immutable Map
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

val data: RDD[(Int, Row)] = df.rdd.keyBy(_.getInt(0))
val fractions: Map[Int, Double] = data.map(_._1)
                                      .distinct
                                      .map(x => (x, 0.8))
                                      .collectAsMap

val sampleData: RDD[Row] = data.sampleByKeyExact(withReplacement = false, fractions, 2L)
                               .values

val sampleDataDF: DataFrame = spark.createDataFrame(sampleData, df.schema) // you can use sqlContext.createDataFrame(...) instead for Spark 1.6
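As a quick sanity check that the sample kept the original schema, you can print it (the output should match that of df.printSchema):

sampleDataDF.printSchema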

You can now check the counts on your keys, the df, or the data sample:

scala> df.count
// [...]
// res9: Long = 12

scala> sampleDataDF.count
// [...]
// res10: Long = 10

Since Spark 1.5.0 you can use the DataFrameStatFunctions.sampleBy method. Note that, unlike sampleByKeyExact, it returns an approximate stratified sample, so the per-stratum counts are not guaranteed:

df.stat.sampleBy("keyColumn", fractions, seed)
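If you reuse the fractions computed above, note that sampleBy expects an immutable Map while collectAsMap returns a scala.collection.Map, so a conversion is needed. A minimal sketch, using the seed from this answer (dfSample is a hypothetical name):

val dfSample = df.stat.sampleBy("keyColumn", fractions.toMap, 2L) // .toMap converts to an immutable Map
dfSample.count // close to, but not guaranteed to be exactly, 10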
