pyspark: Efficiently have partitionBy write to same number of total partitions as original table

You’ve got several options. In my code below I’ll assume you want to write in parquet, but of course you can change that.

(1) df.repartition(numPartitions, *cols).write.partitionBy(*cols).parquet(writePath)

This will first use hash-based partitioning to ensure that only a limited number of distinct values from cols make their way into each partition. Depending on the value you choose for numPartitions, some partitions may be empty while others may be crowded with values, because a hash can send several values to one partition and none to another. Then, when you call partitionBy on the DataFrameWriter, each unique value in each partition will be placed in its own individual file.

Warning: this approach can lead to lopsided partition sizes and lopsided task execution times. This happens when some values in your column are associated with many rows (e.g., in a city column, the file for New York City might have lots of rows), whereas other values are far less numerous (e.g., values for small towns).
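To make the empty-versus-crowded effect concrete, here is a minimal pure-Python sketch. It is only an illustration: CRC32 stands in for Spark's hash function (which is a different one), and the city names and row counts are made up.

```python
import zlib
from collections import Counter

def assign_partitions(keys, num_partitions):
    """Map each distinct key to a partition id using a stand-in hash (CRC32)."""
    return {k: zlib.crc32(str(k).encode("utf-8")) % num_partitions
            for k in set(keys)}

def partition_row_counts(rows, num_partitions):
    """Return the number of rows that land in each partition id."""
    mapping = assign_partitions(rows, num_partitions)
    counts = Counter()
    for key in rows:
        counts[mapping[key]] += 1
    return [counts[p] for p in range(num_partitions)]

# Hypothetical data: one huge city and two small towns.
rows = ["New York City"] * 1000 + ["Smallville"] * 3 + ["Tinytown"] * 2
counts = partition_row_counts(rows, num_partitions=8)
print(counts)  # one entry per partition; most are zero, one is very large
```

With three distinct values and eight partitions, at least five partitions must stay empty, and whichever partition receives New York City does nearly all the work.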

(2) df.sort(sortCols).write.parquet(writePath)

This option works great when you want (1) the files you write to be of nearly equal size and (2) exact control over the number of files written. This approach first globally sorts your data and then finds splits that break up the data into k evenly sized partitions, where k is specified in the Spark config spark.sql.shuffle.partitions. This means that rows sharing a value in a sort column are adjacent to each other, but sometimes a run of them (say, all rows for one user_id when you sort by user_id and then time) will span a split and end up in different files. Thus, if your use case requires all rows with the same key to be in the same partition, then don’t use this approach.
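Here is a small pure-Python sketch of how evenly sized splits over sorted data can cut through a group of related rows. It is an idealized model, not Spark's actual split-finding logic (Spark estimates its split points from a sample), and the user ids and event counts are invented.

```python
def split_sorted(rows, k):
    """Sort rows and cut them into k contiguous chunks of near-equal size."""
    ordered = sorted(rows)
    base, extra = divmod(len(ordered), k)
    chunks, start = [], 0
    for i in range(k):
        end = start + base + (1 if i < extra else 0)
        chunks.append(ordered[start:end])
        start = end
    return chunks

# Hypothetical events as (user_id, time) pairs: u2 is much busier than the rest.
events = [(user, t)
          for user, n in [("u1", 2), ("u2", 7), ("u3", 3)]
          for t in range(n)]

chunks = split_sorted(events, k=3)
sizes = [len(c) for c in chunks]
users_per_chunk = [{user for user, _ in c} for c in chunks]

# Users whose rows ended up in more than one chunk (i.e. more than one file).
spanning = {u for u in {"u1", "u2", "u3"}
            if sum(u in users for users in users_per_chunk) > 1}
print(sizes, spanning)
```

The chunks come out equal in size, but the rows for u2 are spread over all three of them, which is exactly the trade-off described above.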

There are two extra bonuses: (1) by sorting your data, its size on disk can often be reduced (e.g., sorting all events by user_id and then by time will lead to lots of repetition in column values, which aids compression) and (2) if you write to a file format that supports it (like Parquet), then subsequent readers can read data in optimally by using predicate push-down, because the Parquet writer will write the MAX and MIN values of each column in the metadata, allowing the reader to skip chunks of rows whose (min, max) range excludes the values the query specifies.
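The min/max skipping idea can be sketched in a few lines of pure Python. This is a toy model of the statistics check, not the Parquet reader itself; the group size and the queried value are arbitrary choices for illustration.

```python
def min_max_stats(groups):
    """Compute the (min, max) pair a writer would store for each group of rows."""
    return [(min(g), max(g)) for g in groups]

def groups_to_read(stats, value):
    """Return indexes of groups whose (min, max) range could contain value."""
    return [i for i, (lo, hi) in enumerate(stats) if lo <= value <= hi]

values = list(range(100))

# Sorted data: each group of 25 rows covers a narrow, disjoint range.
sorted_groups = [values[i:i + 25] for i in range(0, 100, 25)]
# Unsorted (interleaved) data: every group covers almost the full range.
interleaved_groups = [values[i::4] for i in range(4)]

sorted_hits = groups_to_read(min_max_stats(sorted_groups), 60)
interleaved_hits = groups_to_read(min_max_stats(interleaved_groups), 60)
print(sorted_hits, interleaved_hits)  # → [2] [0, 1, 2, 3]
```

With sorted data a lookup for the value 60 only has to touch one group; with interleaved data the statistics rule nothing out and all four groups must be read.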

Note that sorting in Spark is more expensive than just repartitioning and requires an extra stage. Behind the scenes Spark will first determine the splits in one stage, and then shuffle the data into those splits in another stage.

(3) df.rdd.partitionBy(customPartitioner).toDF().write.parquet(writePath)

If you’re using Spark with Scala, then you can write a custom partitioner, which can get around the annoying gotchas of the hash-based partitioner. That is not directly an option in PySpark, unfortunately. If you really want custom partitioning in PySpark, I’ve found this is possible, albeit a bit awkward, by using rdd.repartitionAndSortWithinPartitions:

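# sort_key_function and part_func are functions you define yourself.
# The parentheses replace the backslashes, so the inline comments are legal.
(df.rdd
   .keyBy(sort_key_function)   # convert to (key, row) pairs
   .repartitionAndSortWithinPartitions(numPartitions=N_WRITE_PARTITIONS,
                                       partitionFunc=part_func)
   .values()                   # get rid of the keys
   .toDF()
   .write.parquet(writePath))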
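For what it's worth, here is one hypothetical way the two helper functions could look. Everything in it is an assumption for illustration: the city field, the list of heavy keys, and the idea of giving each heavy key a dedicated partition. CRC32 is used instead of Python's built-in hash because string hashes can differ between Python processes unless PYTHONHASHSEED is fixed.

```python
import zlib

N_WRITE_PARTITIONS = 8

# Hypothetical hot keys, each pinned to its own partition id.
HEAVY_KEYS = {"New York City": 0, "Los Angeles": 1}

def sort_key_function(row):
    """Key each row by its city; assumes rows expose a 'city' field."""
    return row["city"]

def part_func(key):
    """Return a partition id: a reserved one for heavy keys, a hashed one otherwise."""
    if key in HEAVY_KEYS:
        return HEAVY_KEYS[key]
    remaining = N_WRITE_PARTITIONS - len(HEAVY_KEYS)
    # Spread all other keys over the partitions not reserved for heavy keys.
    return len(HEAVY_KEYS) + zlib.crc32(str(key).encode("utf-8")) % remaining

print(part_func("New York City"), part_func("Boise"))
```

PySpark takes the value returned by partitionFunc modulo numPartitions, so returning ids in the range 0 to N_WRITE_PARTITIONS - 1 keeps the mapping exactly as written. Note that a pinned key still ends up in a single file; this only stops it from sharing a partition with other keys.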

Maybe someone else knows an easier way to use a custom partitioner on a dataframe in pyspark?
