How to define and use a User-Defined Aggregate Function in Spark SQL?

Supported methods

Spark >= 3.0

Scala UserDefinedAggregateFunction is deprecated as of Spark 3.0 (SPARK-30423, Deprecate UserDefinedAggregateFunction) in favor of an Aggregator registered with functions.udaf.

Spark >= 2.3

Vectorized UDF (Python only):

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType

from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        df = pd.DataFrame(
           df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df

    return below_threshold_

Example usage:

df.groupBy("group").apply(below_threshold(-40)).show()

## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## |    b|           true|
## |    a|          false|
## +-----+---------------+

See also Applying UDFs on GroupedData in PySpark (with functioning python example)

Spark >= 2.0 (optionally 1.6 but with slightly different API):

It is possible to use Aggregators on typed Datasets:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean)  extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false
  def reduce(acc: Boolean, x: I) = acc | f(x)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
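
To make the zero / reduce / merge / finish contract above more concrete, here is a minimal plain-Python sketch that models it without Spark. BelowThresholdModel and aggregate_by_key are hypothetical names made up for this illustration, and the nested list stands in for partitions; this is not how Spark schedules the work, only a model of what each method is responsible for.

```python
from typing import Callable, Dict, Generic, Hashable, Iterable, TypeVar

I = TypeVar("I")


class BelowThresholdModel(Generic[I]):
    """Plain-Python stand-in for the Aggregator contract (hypothetical, not a Spark class)."""

    def __init__(self, predicate: Callable[[I], bool]) -> None:
        self.predicate = predicate

    def zero(self) -> bool:
        # Neutral element of merge: merge(zero, x) == x.
        return False

    def reduce(self, acc: bool, x: I) -> bool:
        # Fold one input record into a partition-local buffer.
        return acc or self.predicate(x)

    def merge(self, acc1: bool, acc2: bool) -> bool:
        # Combine two buffers that were built independently.
        return acc1 or acc2

    def finish(self, acc: bool) -> bool:
        # Turn the final buffer into the output value.
        return acc


def aggregate_by_key(
    agg: BelowThresholdModel,
    partitions: Iterable[Iterable[I]],
    key: Callable[[I], Hashable],
) -> Dict[Hashable, bool]:
    """Simulate a grouped aggregation over data that is already split into partitions."""
    merged: Dict[Hashable, bool] = {}
    for partition in partitions:
        local: Dict[Hashable, bool] = {}
        for record in partition:
            k = key(record)
            local[k] = agg.reduce(local.get(k, agg.zero()), record)
        for k, buf in local.items():
            merged[k] = agg.merge(merged.get(k, agg.zero()), buf)
    return {k: agg.finish(buf) for k, buf in merged.items()}


# Same rows as in the answer, split arbitrarily into two "partitions".
partitions = [[("a", 0), ("b", 30)], [("a", 1), ("b", -50)]]
agg = BelowThresholdModel(lambda r: r[1] < -40)
result = aggregate_by_key(agg, partitions, key=lambda r: r[0])
print(result)  # {'a': False, 'b': True}
```

Because merge is associative and zero is its neutral element, the result does not depend on how the rows are split across partitions, which is the property the real Aggregator relies on.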

Spark >= 1.5:

In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("power", IntegerType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("ind", BooleanType)
    // Returned type
    def dataType = BooleanType
    // Self-explanatory
    def deterministic = true
    // zero value
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
    // Similar to seqOp in aggregate
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
          buffer.update(0, buffer.getBoolean(0) | (input.getInt(0) < -40))
    }
    // Similar to combOp in aggregate
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))    
    }
    // Called on exit to get return value
    def evaluate(buffer: Row) = buffer.getBoolean(0)
}

Example usage:

df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+

Spark 1.4 workaround:

I am not sure I understand your requirements correctly, but as far as I can tell, plain old aggregation should be enough here:

import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.IntegerType

val df = sc.parallelize(Seq(
    ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
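
The reason this works is that casting the comparison to an integer gives a 0/1 indicator, so a non-zero sum per group means at least one row was below the threshold. Here is a minimal plain-Python sketch of the same arithmetic on the sample rows (no Spark involved, and it assumes there are no null values in the column):

```python
from collections import defaultdict

rows = [("a", 0), ("a", 1), ("b", 30), ("b", -50)]
threshold = -40

# Sum the 0/1 indicator per group, as withColumn(...).groupBy(...).agg(sum(...)) does.
counts = defaultdict(int)
for group, power in rows:
    counts[group] += int(power < threshold)

# A non-zero count means at least one value fell below the threshold.
below = {group: count != 0 for group, count in counts.items()}
print(below)  # {'a': False, 'b': True}
```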

Spark <= 1.4:

As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).

Unsupported / internal methods

Internally Spark uses a number of classes including ImperativeAggregates and DeclarativeAggregates.

These are intended for internal usage and may change without further notice, so they are probably not something you want to use in production code. For completeness, though, BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression) 
    extends  DeclarativeAggregate  {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
} 

It should be further wrapped with an equivalent of withAggregateFunction.
