from pyspark.sql import functions as F
from pyspark.sql import Window
w = Window.partitionBy('id').orderBy('date')
sorted_list_df = input_df.withColumn(
    'sorted_list', F.collect_list('value').over(w)
)\
    .groupBy('id')\
    .agg(F.max('sorted_list').alias('sorted_list'))
Window examples provided by users often don’t really explain what is going on, so let me dissect this one for you.
As you know, using collect_list together with groupBy does not guarantee the order of the values in the resulting list. Depending on how your data is partitioned, Spark appends values to the list as soon as it finds a row in the group, so the order depends on how Spark plans your aggregation over the executors.
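To make the ordering problem concrete, here is a toy model in plain Python. It is not Spark code and not a description of Spark's internals; the two "partitions" and their contents are made-up data, and the only point is that the final list depends on the order in which partial results are combined.

```python
# Toy model (plain Python, not Spark): rows for a single id are spread
# over two hypothetical partitions as (date, value) pairs.
partition_a = [("2020-01-01", 10), ("2020-01-03", 30)]
partition_b = [("2020-01-02", 20)]


def collect(partitions):
    """Concatenate the values of each partition in the order given."""
    return [value for part in partitions for _, value in part]


# The same rows, combined in two different orders.
merge_ab = collect([partition_a, partition_b])
merge_ba = collect([partition_b, partition_a])

print(merge_ab)
print(merge_ba)
```

Neither result is sorted by date, and the two results differ from each other, which is the situation the Window is meant to avoid.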
A Window function lets you control that situation by grouping rows by a certain value, so you can perform an operation over each of the resulting groups:
w = Window.partitionBy('id').orderBy('date')
partitionBy: you want groups/partitions of rows with the same id
orderBy: you want the rows in each group to be sorted by date
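If it helps to see the scope of the Window without Spark, the following plain-Python sketch emulates it on a handful of made-up rows. The helper name window_groups and the sample data are assumptions for illustration only; Spark does this work in a distributed way, so this is just a mental model.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical input rows, deliberately out of order.
rows = [
    {"id": 1, "date": "2020-01-03", "value": 30},
    {"id": 2, "date": "2020-01-02", "value": 5},
    {"id": 1, "date": "2020-01-01", "value": 10},
    {"id": 1, "date": "2020-01-02", "value": 20},
    {"id": 2, "date": "2020-01-01", "value": 7},
]


def window_groups(rows):
    """Group rows by id (like partitionBy) and sort each group by date (like orderBy)."""
    ordered = sorted(rows, key=itemgetter("id", "date"))
    return {key: list(group) for key, group in groupby(ordered, key=itemgetter("id"))}


groups = window_groups(rows)
for key, group in groups.items():
    print(key, [row["value"] for row in group])
```

Each id ends up with its own group, and the values inside a group follow the date order rather than the input order.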
Once you have defined the scope of your Window (“rows with the same id, sorted by date”), you can use it to perform an operation over it, in this case a collect_list:
F.collect_list('value').over(w)
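At this point you have created a new column sorted_list with an ordered list of values, sorted by date, but you still have one row per input row for each id. Because a Window with an orderBy defaults to a frame that runs from the start of the partition up to the current row, each row holds the list accumulated so far, and only the last one is complete. To trim out the extra rows you want to groupBy id and keep the max value for each group, which is the complete list: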
.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
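Why max returns the full list can be seen in a small plain-Python sketch. It emulates the running lists produced for a single id and is not Spark code; the sample values are made up. Python compares lists element by element and treats a prefix as smaller than the longer list, and the sketch assumes Spark's array ordering behaves the same way in this case.

```python
# Values for a single id, already sorted by date (hypothetical data).
# They are deliberately not in ascending order, to show that max picks
# the longest running list rather than the largest numbers.
values_by_date = [30, 10, 20]

# One list per row: everything from the start of the group up to that row.
running_lists = [values_by_date[: i + 1] for i in range(len(values_by_date))]

# Every shorter list is a prefix of the next one, so the maximum is the full list.
full_list = max(running_lists)

print(running_lists)
print(full_list)
```

The result keeps the date order of the values, because every shorter list is a prefix of the complete one.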