How to improve broadcast join speed with a BETWEEN condition in Spark

The trick is to rewrite the join condition so that it contains an equality (=) component, which Spark can use to optimize the query and narrow down the possible matches. For numeric values, you can bucketize the data and use the buckets in the join condition.

Let’s say your data looks like this:

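// Needed for rand, explode, udf, broadcast and the $"..." column syntax
// (spark.implicits._ is already in scope in spark-shell)
import org.apache.spark.sql.functions._
import spark.implicits._

val a = spark.range(100000)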
  .withColumn("cur", (rand(1) * 1000).cast("bigint"))

val b = spark.range(100)
  .withColumn("low", (rand(42) * 1000).cast("bigint"))
  .withColumn("high", ($"low" + rand(-42) * 10).cast("bigint"))

First choose a bucket size appropriate for your data. In this case we can use 50:

val bucketSize = 50L

Assign a bucket to each row of a:

val aBucketed = a.withColumn(
  "bucket", ($"cur" / bucketSize).cast("bigint") * bucketSize
)

Create a UDF that emits every bucket overlapped by a range:

def get_buckets(bucketSize: Long) = 
  udf((low: Long, high: Long) => {
    val min = (low / bucketSize) * bucketSize
    val max = (high / bucketSize) * bucketSize
    (min to max by bucketSize).toSeq
  })

Then bucket b, producing one row per overlapped bucket:

val bBucketed = b.withColumn(
  "bucket", explode(get_buckets(bucketSize)($"low",  $"high"))
)
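The reason the extra equality is safe can be checked outside Spark. The following is a minimal sketch in plain Scala, assuming non-negative values as in the example data; the object name BucketSketch and its helpers are illustrative and not part of the original answer. It mirrors the arithmetic of the bucket column and of get_buckets: any value inside [low, high] falls into one of the buckets emitted for that range, so matching on bucket first cannot drop a valid pair.

```scala
object BucketSketch {
  // Bucket of a single value: the largest multiple of bucketSize that is <= value
  // (holds for non-negative values, as in the example data).
  def bucketOf(value: Long, bucketSize: Long): Long =
    (value / bucketSize) * bucketSize

  // Every bucket overlapped by the closed range [low, high].
  def bucketsOf(low: Long, high: Long, bucketSize: Long): Seq[Long] =
    (bucketOf(low, bucketSize) to bucketOf(high, bucketSize) by bucketSize).toList

  def main(args: Array[String]): Unit = {
    val bucketSize = 50L
    println(bucketOf(123L, bucketSize))       // 100
    println(bucketsOf(98L, 104L, bucketSize)) // List(50, 100)

    // Each value inside the range lands in one of the range's buckets.
    val covered = (98L to 104L).forall { cur =>
      bucketsOf(98L, 104L, bucketSize).contains(bucketOf(cur, bucketSize))
    }
    println(covered) // true
  }
}
```

A range that straddles a bucket boundary, such as 98 to 104 here, is emitted once per bucket, which is why b is exploded before the join.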

Finally, use the bucket in the join condition:

aBucketed.join(
  broadcast(bBucketed), 
  aBucketed("bucket") === bBucketed("bucket") && 
    $"cur" >= $"low" &&  
    $"cur" <= $"high",
  "leftouter"
)

This way Spark will use BroadcastHashJoin:

*BroadcastHashJoin [bucket#184L], [bucket#178L], LeftOuter, BuildRight, ((cur#98L >= low#105L) && (cur#98L <= high#109L))
:- *Project [id#95L, cur#98L, (cast((cast(cur#98L as double) / 50.0) as bigint) * 50) AS bucket#184L]
:  +- *Project [id#95L, cast((rand(1) * 1000.0) as bigint) AS cur#98L]
:     +- *Range (0, 100000, step=1, splits=Some(8))
+- BroadcastExchange HashedRelationBroadcastMode(List(input[3, bigint, false]))
   +- Generate explode(if ((isnull(low#105L) || isnull(high#109L))) null else UDF(low#105L, high#109L)), true, false, [bucket#178L]
      +- *Project [id#102L, low#105L, cast((cast(low#105L as double) + (rand(-42) * 10.0)) as bigint) AS high#109L]
         +- *Project [id#102L, cast((rand(42) * 1000.0) as bigint) AS low#105L]
            +- *Range (0, 100, step=1, splits=Some(8))

instead of BroadcastNestedLoopJoin:

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, LeftOuter, ((cur#98L >= low#105L) && (cur#98L <= high#109L))
:- *Project [id#95L, cast((rand(1) * 1000.0) as bigint) AS cur#98L]
:  +- *Range (0, 100000, step=1, splits=Some(8))
+- BroadcastExchange IdentityBroadcastMode
   +- *Project [id#102L, low#105L, cast((cast(low#105L as double) + (rand(-42) * 10.0)) as bigint) AS high#109L]
      +- *Project [id#102L, cast((rand(42) * 1000.0) as bigint) AS low#105L]
         +- *Range (0, 100, step=1, splits=Some(8))

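You can tune the bucket size to balance precision against data size: smaller buckets leave fewer false candidates per bucket but produce more exploded rows in b, while larger buckets do the opposite.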
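To get a feel for the trade-off, here is a small sketch in plain Scala that counts how many rows one range turns into after the explode step for several bucket sizes. The object and function names are hypothetical, and the range 120 to 340 is an arbitrary illustration rather than data from the example above.

```scala
object BucketSizeSketch {
  // Number of rows a single [low, high] range produces after explode,
  // using the same arithmetic as get_buckets (non-negative values assumed).
  def explodedRows(low: Long, high: Long, bucketSize: Long): Long = {
    val min = (low / bucketSize) * bucketSize
    val max = (high / bucketSize) * bucketSize
    (max - min) / bucketSize + 1
  }

  def main(args: Array[String]): Unit = {
    for (size <- Seq(10L, 50L, 100L, 1000L)) {
      println(s"bucketSize=$size -> ${explodedRows(120L, 340L, size)} row(s)")
    }
    // bucketSize=10 -> 23 row(s)
    // bucketSize=50 -> 5 row(s)
    // bucketSize=100 -> 3 row(s)
    // bucketSize=1000 -> 1 row(s)
  }
}
```

With a bucket size of 1000 and values below 1000, every row ends up in bucket 0, so the equality no longer narrows anything and each row of a is still compared with every row of b.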

If you don’t mind a lower-level solution, broadcast a sorted sequence with constant-time item access (such as an Array or Vector) and do the join with a UDF that runs a binary search over it.
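A minimal sketch of the lookup part in plain Scala follows. It rests on an assumption the original answer does not state: the intervals are sorted by low and do not overlap, so a value can match at most one of them. The Interval case class and the find function are hypothetical names. The randomly generated ranges in the example above may overlap, in which case this lookup would return only one of the matches.

```scala
object RangeLookupSketch {
  // Hypothetical row type for the small (broadcast) side of the join.
  final case class Interval(id: Long, low: Long, high: Long)

  // Assumes `sorted` is ordered by `low` and the intervals do not overlap.
  // Returns the interval containing `cur`, if there is one.
  def find(sorted: Array[Interval], cur: Long): Option[Interval] = {
    var lo = 0
    var hi = sorted.length - 1
    var candidate = -1
    // Binary search for the right-most interval whose low is <= cur.
    while (lo <= hi) {
      val mid = lo + (hi - lo) / 2
      if (sorted(mid).low <= cur) { candidate = mid; lo = mid + 1 }
      else hi = mid - 1
    }
    // The candidate matches only if cur also lies below its upper bound.
    if (candidate >= 0 && cur <= sorted(candidate).high) Some(sorted(candidate))
    else None
  }

  def main(args: Array[String]): Unit = {
    val intervals = Array(Interval(1, 10, 20), Interval(2, 30, 35), Interval(3, 50, 50))
    println(find(intervals, 15L)) // Some(Interval(1,10,20))
    println(find(intervals, 25L)) // None
  }
}
```

In Spark, the sorted array would be shipped to the executors with spark.sparkContext.broadcast and find would be called from inside a udf on the cur column; that wiring is left out here so the sketch runs without a cluster.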

You should also take a look at the number of partitions. 8 partitions for 100GB seems pretty low.
