Serialize a custom transformer written in Python for use within a PySpark ML pipeline

As of Spark 2.3.0 there’s a much, much better way to do this.

Simply extend DefaultParamsWritable and DefaultParamsReadable. Your class then automatically gets write/save and read/load methods that persist its params, and the PipelineModel serialization system will use those methods.
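To see why inheriting two mixins is enough, here is a self-contained toy model of the pattern in plain Python, runnable without Spark. ToyParamsWriter, ToyParamsWritable, ToyParamsReadable, ToySetValue and the metadata.json file are all hypothetical stand-ins, not pyspark APIs. The writer records the class's dotted name next to its params, and the loader rebuilds an instance by calling the class with no arguments and then restoring the params.

```python
import json
import os


class ToyParamsWriter:
    """Hypothetical stand-in for a params writer: saves class name + params as JSON."""

    def __init__(self, instance):
        self.instance = instance

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        cls = type(self.instance)
        metadata = {
            # Record where the class lives so a loader could find it again.
            "class": cls.__module__ + "." + cls.__name__,
            "paramMap": self.instance.params,
        }
        with open(os.path.join(path, "metadata.json"), "w", encoding="utf-8") as fh:
            json.dump(metadata, fh)


class ToyParamsWritable:
    """Mixin: every subclass gains write(), returning a writer bound to itself."""

    def write(self):
        return ToyParamsWriter(self)


class ToyParamsReadable:
    """Mixin: every subclass gains a load() classmethod."""

    @classmethod
    def load(cls, path):
        with open(os.path.join(path, "metadata.json"), encoding="utf-8") as fh:
            metadata = json.load(fh)
        instance = cls()  # the loader relies on a no-argument constructor
        instance.params.update(metadata["paramMap"])
        return instance


class ToySetValue(ToyParamsWritable, ToyParamsReadable):
    """A toy 'transformer' whose only state is its params dict."""

    def __init__(self, value=0.0):
        self.params = {"value": value}
```

Usage: ToySetValue(value=123.0).write().save(path) followed by ToySetValue.load(path) returns an instance with the same params, which is the same shape as the pyspark round trip shown further down.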

The docs were not really clear, and I had to do a bit of source reading to understand how deserialization works.

  • PipelineModel.read instantiates a PipelineModelReader
  • PipelineModelReader loads metadata and checks if language is 'Python'. If it’s not, then the typical JavaMLReader is used (what most of these answers are designed for)
  • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance
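A toy model of that read-side dispatch, in plain Python so it runs without Spark. The function names are hypothetical. The paramMap "language" and "stageUids" keys and the zero-padded "<index>_<uid>" stage directories reflect my reading of the pyspark source, so treat them as assumptions and check them against your Spark version.

```python
import os


def pipeline_reader_for(metadata):
    """Toy version of the PipelineModelReader check: which reader handles this path?"""
    # Pipelines saved through the Python path store language="Python" in paramMap;
    # anything else is assumed to have been written by the Scala/Java side.
    param_map = metadata.get("paramMap", {})
    if param_map.get("language") != "Python":
        return "JavaMLReader"
    return "PipelineSharedReadWrite"


def python_stage_paths(metadata, path):
    """List the per-stage directories that would each go through loadParamsInstance."""
    uids = metadata["paramMap"]["stageUids"]
    # Assumed layout: the index is zero-padded to the width of the stage count.
    width = len(str(len(uids)))
    stages_dir = os.path.join(path, "stages")
    return [
        os.path.join(stages_dir, str(i).zfill(width) + "_" + uid)
        for i, uid in enumerate(uids)
    ]
```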

loadParamsInstance reads the class name from the saved metadata, resolves that class, and calls its load(path) class method. If you extend DefaultParamsReadable, read() returns a DefaultParamsReader, so you get DefaultParamsReader.load automatically: it instantiates the class with no arguments, restores the saved uid, and sets the saved params. If you do have specialized deserialization logic to implement, I would look at that load method as a starting place.
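Here is a plain-Python sketch of that lookup, assuming the metadata's class field is a dotted module path. The helper names are hypothetical; the org.apache.spark to pyspark renaming mirrors what I saw in loadParamsInstance. One practical consequence: a class defined in a notebook or script should be recorded as __main__.SetValueTransformer, so the same class must be importable under that name in the process that calls PipelineModel.load.

```python
import importlib


def to_python_class_name(saved_class):
    # Stages written from Scala record JVM names such as
    # "org.apache.spark.ml.feature.Tokenizer"; map them onto the pyspark package.
    return saved_class.replace("org.apache.spark", "pyspark")


def resolve_class(dotted_name):
    # "package.module.ClassName" -> import the module, then fetch the attribute.
    module_name, _, class_name = dotted_name.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def load_params_instance(saved_class, path):
    # Resolve the saved class and defer to its own load() classmethod; for a
    # DefaultParamsReadable subclass that ends up in DefaultParamsReader.load.
    cls = resolve_class(to_python_class_name(saved_class))
    return cls.load(path)
```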

On the opposite side:

  • PipelineModel.write checks whether all stages are Java-backed (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for)
  • Otherwise, PipelineModelWriter is used (PipelineWriter for an unfitted Pipeline), which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
  • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.
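The write-side dispatch can be modeled the same way. MLWritable and JavaMLWritable below are empty hypothetical stand-ins for the pyspark mixins of the same names, and the returned strings only name the writer that would be chosen; the real classes carry the actual write logic.

```python
class MLWritable:
    """Hypothetical stand-in for pyspark.ml.util.MLWritable."""


class JavaMLWritable(MLWritable):
    """Hypothetical stand-in for pyspark.ml.util.JavaMLWritable."""


def choose_pipeline_writer(stages):
    """Toy version of PipelineModel.write(): pick a writer for the given stages."""
    # All stages backed by JVM objects: hand the whole pipeline to the Scala writer.
    if all(isinstance(stage, JavaMLWritable) for stage in stages):
        return "JavaMLWriter"
    # Otherwise every stage must at least know how to write itself.
    for stage in stages:
        if not isinstance(stage, MLWritable):
            raise ValueError("stage %r is not MLWritable" % (stage,))
    return "PipelineModelWriter"
```

A pipeline that mixes built-in Java stages with one DefaultParamsWritable Python stage therefore takes the Python path for every stage, which is why the custom transformer does not need any JVM counterpart.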

You can extend DefaultParamsWritable to get a write() method that returns a DefaultParamsWriter, which saves the metadata for your class and its params in the right format. If you have custom serialization logic you need to implement, I would look at DefaultParamsWritable and DefaultParamsWriter as a starting point.
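Before writing custom serialization, it can help to look at what DefaultParamsWriter actually produced for a stage. This helper assumes the metadata is a Hadoop-style text output directory, <path>/metadata/part-*, whose first line is one JSON record with keys such as class, uid and paramMap. That layout is my reading of the pyspark source, so verify it against your Spark version; for the example below, the stage directories would sit under /tmp/example_pyspark_pipeline/stages/.

```python
import glob
import json
import os


def read_metadata(path):
    """Return the JSON metadata record saved under <path>/metadata/."""
    # Text output is split into part-00000, part-00001, ...; the metadata
    # record is assumed to be a single JSON line in the first part file.
    parts = sorted(glob.glob(os.path.join(path, "metadata", "part-*")))
    if not parts:
        raise FileNotFoundError("no metadata part files under " + path)
    with open(parts[0], encoding="utf-8") as fh:
        return json.loads(fh.readline())


def describe_stage(path):
    """Summarize the class, uid and explicitly set params of a saved stage."""
    meta = read_metadata(path)
    return meta["class"], meta["uid"], meta.get("paramMap", {})
```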

So, finally, suppose you have a simple transformer that extends Params and stores all of its parameters in the typical Params fashion:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasOutputCols
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit  # used by the simple _transform below

class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
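    # Class-level Param with a placeholder parent, the usual pyspark pattern.
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
        typeConverter=TypeConverters.toFloat,
    )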

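    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        # Every argument needs a default: on load, DefaultParamsReader
        # instantiates the class with no arguments and then sets the saved params.
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        # @keyword_only stores the keyword arguments actually passed in _input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)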

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        # Add (or overwrite) each output column with the constant value.
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset

Now we can use it:

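from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession

# The pyspark shell already provides `spark` and `sc`; create them in a plain script.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()

# Round trip: save the fitted PipelineModel, load it back, and compare params.
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()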

Result:

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+
