
Apache Spark - Save Time with Less Join Operations

Some of my surprising findings when optimizing Spark code

Not long ago, I was tasked with improving our Spark application's runtime performance, because some modules took 7 hours or more to complete on datasets that are not particularly large. These are my conclusions and afterthoughts after countless hours of staring at our codebase and the YARN application tracking UI. I was able to reduce runtime by 30% to 80%, depending on how well each module was written and the nature of its operations. The modules with the largest reductions had one thing in common: refactored join operations.

Refactoring Joins

Join is a slow operation, as everyone knows, but sometimes it is simply unavoidable. So the real question is how to avoid unnecessary joins. Interestingly, the examples I found in the codebase all involved groupBy operations as well, and the gist can be summarized simply: use a single groupBy to accomplish as much as possible. Here are a few (Scala-style) pseudocode examples where removing the joins brought drastic runtime improvements:

1.

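// "colA", "colB", "colC" are column names; cond1 and cond2 are boolean Columns;
// sth is the placeholder default value from the original pseudocode
import org.apache.spark.sql.functions._

// Original: two filtered aggregations, then an (inner) join on colA
df.filter(cond1)
  .groupBy("colA")
  .agg(sum("colB").as("colB"))
  .join(
    df.filter(cond2)
      .groupBy("colA")
      .agg(max("colB").as("colC")),
    "colA")

// Optimized: one groupBy with conditional aggregations
// (the sum only matches the filtered sum if sth is 0 or otherwise is omitted)
df.groupBy("colA")
  .agg(
    sum(when(cond1, col("colB")).otherwise(sth)).as("colB"),
    max(when(cond2, col("colB"))).as("colC"))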

Essentially, this code computes two different aggregated statistics for each colA group, each under its own condition. The join then places the two statistics side by side as columns, one row per group.

In this case, Spark's when function (likely borrowed from SQL's CASE WHEN) can replace the conditional filtering. Unless a default value or column is specified through otherwise, when returns null for rows that fail the condition. Spark's built-in aggregate functions (max, sum, avg, etc.) ignore null values, so combining them with when yields an aggregation with a condition. One difference to keep in mind: the inner join in the original drops any colA group that lacks rows for either condition, while the single groupBy keeps that group with a null (or default) value. The optimized version is much more efficient, especially as more conditional aggregations are added.
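To make the null-skipping argument concrete without a cluster, here is a plain Scala model (standard library only, not Spark) of both shapes. It is a sketch under stated assumptions: None stands in for SQL null, flatMap over Option mimics an aggregate skipping nulls, and Row1, its flag field, and the two flag-based conditions are hypothetical stand-ins for colA, colB, cond1, and cond2.

```scala
// Plain Scala model of example 1 (no Spark): None plays the role of SQL null.
// Row1 and the flag-based conditions are hypothetical stand-ins.
case class Row1(a: String, b: Long, flag: Int)

object ConditionalAgg {
  def cond1(r: Row1): Boolean = r.flag == 1 // stand-in for cond1
  def cond2(r: Row1): Boolean = r.flag == 2 // stand-in for cond2

  // Original shape: filter + groupBy twice, then an inner join on the key.
  def viaJoin(rows: Seq[Row1]): Map[String, (Long, Long)] = {
    val sums = rows.filter(cond1).groupBy(_.a).map { case (k, rs) => k -> rs.map(_.b).sum }
    val maxs = rows.filter(cond2).groupBy(_.a).map { case (k, rs) => k -> rs.map(_.b).max }
    // Inner join: keep only keys present on both sides.
    sums.collect { case (k, s) if maxs.contains(k) => k -> ((s, maxs(k))) }
  }

  // Optimized shape: one groupBy. Like when(cond, b) without otherwise,
  // rows failing the condition map to None and are skipped by the aggregate;
  // an aggregate over no non-null values stays None (Spark returns null).
  def viaSingleGroupBy(rows: Seq[Row1]): Map[String, (Option[Long], Option[Long])] =
    rows.groupBy(_.a).map { case (k, rs) =>
      val forSum = rs.flatMap(r => if (cond1(r)) Some(r.b) else None)
      val forMax = rs.flatMap(r => if (cond2(r)) Some(r.b) else None)
      k -> ((forSum.reduceOption(_ + _), forMax.reduceOption(_ max _)))
    }
}
```

On an input where some key has no cond2 rows, viaSingleGroupBy still reports that key (with a None max), whereas viaJoin drops it; for keys present on both sides the two agree.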

2.

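import org.apache.spark.sql.functions._

// Original: aggregate per (a, b), then join the b == 1 and b == 2 slices on a
val res = df.groupBy("a", "b")
  .agg(...as("c"))
res.filter(col("b") === 1)
  .select(col("a"), col("c").as("b1"))
  .join(
    res.filter(col("b") === 2)
      .select(col("a"), col("c").as("b2")),
    "a")

// Optimized: pivot b into one column per listed value
df.groupBy("a")
  .pivot("b", Seq(1, 2))
  .agg(first("c"))
  .withColumnRenamed(...)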

The key is understanding what this code is trying to do. In short, it groups the table by column a and then pivots on column b. A pivot can be thought of as a groupBy on a column whose results are laid out as columns instead of rows.

With that in mind, can we simply rewrite this code with the pivot operator? Wait: here we only care about the values 1 and 2 in column b. Worry not, Spark has this covered. pivot accepts an optional second parameter, a sequence of values, and creates exactly one column per value in that sequence, regardless of whether the value actually occurs in the pivot column. Supplying the values also spares Spark an extra pass to compute the distinct values of the pivot column.

Since we know there is only a single row for each unique combination of values in columns a and b, the first aggregation function is safe to use. Note that, unlike the inner join, the pivot keeps groups that lack one of the values, with null in the corresponding column.
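A similar plain Scala model (standard library only, not Spark) sketches what pivot with an explicit value list produces: one output slot per requested value, holding the first matching c, or None when the group has no row for that value. Row2 and its fields are hypothetical stand-ins for columns a, b, and c of res.

```scala
// Plain Scala model of example 2 (no Spark); Row2 is a hypothetical stand-in.
case class Row2(a: String, b: Int, c: Double)

object PivotModel {
  // Original shape: select b == 1 and b == 2 separately, inner-join on a.
  def viaJoin(res: Seq[Row2]): Map[String, (Double, Double)] = {
    val b1 = res.filter(_.b == 1).map(r => r.a -> r.c).toMap
    val b2 = res.filter(_.b == 2).map(r => r.a -> r.c).toMap
    b1.collect { case (k, v1) if b2.contains(k) => k -> ((v1, b2(k))) }
  }

  // Pivot shape: one slot per requested value, in the order given,
  // filled with the first matching c (like first(c)) or None (null).
  // Values of b that are not in `values` are ignored.
  def viaPivot(res: Seq[Row2], values: Seq[Int]): Map[String, Seq[Option[Double]]] =
    res.groupBy(_.a).map { case (k, rs) =>
      k -> values.map(v => rs.find(_.b == v).map(_.c))
    }
}
```

Calling viaPivot(res, Seq(1, 2)) always yields two slots per key, which corresponds to the fixed set of output columns Spark creates from the explicit value list.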

3.

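import org.apache.spark.sql.functions._

// Original: per-group threshold, joined back onto every row of df
val impt = df.groupBy("a", "b")
  .agg(...as("impt_val"))
df.join(impt, Seq("a", "b"))
  .filter(col("c") > col("impt_val"))
  .select("a", "b", "d")

// Optimized: collect (c, d) alongside the threshold, then explode
df.withColumn("tmp", struct(col("c"), col("d")))
  .groupBy("a", "b")
  .agg(...as("impt_val"),
       collect_list("tmp").as("tmp"))
  .withColumn("tmp", explode(col("tmp")))
  .filter(col("tmp.c") > col("impt_val"))
  .select(col("a"), col("b"), col("tmp.d"))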

This code first calculates an aggregated statistic per group and then uses it as a threshold to filter the original dataframe. The optimization may seem counterintuitive, but in practice, when there are many groups (as defined by columns a and b), the join takes a significant amount of time.

The optimized version uses a single groupBy to obtain not only the threshold but also the list of values needed by the later filter and select. The needed columns are packed into one compound column with struct, collected per group with collect_list, and unpacked with explode, which emits one row per list element and repeats every other column on each row. collect_list is not especially efficient, since each group's list is materialized in memory, but the overall execution time was still drastically lower.
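The equivalence of the two shapes can again be checked with a plain Scala model (standard library only, not Spark). The original elides the aggregation behind impt_val, so this sketch ASSUMES it is the mean of c per (a, b) group; Row3 and its fields are hypothetical stand-ins for the dataframe's columns.

```scala
// Plain Scala model of example 3 (no Spark). Row3 is a hypothetical stand-in,
// and the elided aggregation is ASSUMED here to be the mean of c per group.
case class Row3(a: String, b: String, c: Double, d: String)

object ThresholdModel {
  private def imptVal(rs: Seq[Row3]): Double = rs.map(_.c).sum / rs.size

  // Original shape: aggregate per (a, b), join back to every row, then filter.
  def viaJoin(df: Seq[Row3]): Seq[(String, String, String)] = {
    val impt: Map[(String, String), Double] =
      df.groupBy(r => (r.a, r.b)).map { case (k, rs) => k -> imptVal(rs) }
    df.flatMap { r =>
      impt.get((r.a, r.b)) // the join on (a, b)
        .filter(t => r.c > t)
        .map(_ => (r.a, r.b, r.d))
    }
  }

  // Optimized shape: one groupBy yielding both the threshold and the
  // collected (c, d) pairs (struct + collect_list), then a flatten
  // (explode), a filter on c, and a projection to d.
  def viaSingleGroupBy(df: Seq[Row3]): Seq[(String, String, String)] =
    df.groupBy(r => (r.a, r.b)).toSeq.flatMap { case ((a, b), rs) =>
      val t = imptVal(rs)
      val tmp = rs.map(r => (r.c, r.d)) // collect_list(struct(c, d))
      tmp.filter { case (c, _) => c > t }.map { case (_, d) => (a, b, d) }
    }
}
```

The model only shows that both shapes select the same rows (possibly in a different order); it says nothing about runtime, which in Spark is governed by how many shuffles each plan needs.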

Afterthought

Perhaps the main culprit here is partition shuffling during groupBy and join. Because a groupBy and a join may partition the data differently, a third shuffle may be needed to bring rows with the same key onto the same executor. In Spark, a shuffle typically moves data across the network between executors, so it incurs significant overhead. In contrast, refactoring the work into a single groupBy requires only one shuffle and therefore runs much faster.

One (actually two) more thing

While staring at the YARN application tracking UI, I also discovered a few small steps that can impact performance significantly.

DataFrame df.count operation: count looks harmless, but unlike its pandas counterpart, it is an action that triggers execution of the full query plan. In short, getting the row count of a Spark DataFrame can take roughly as long as computing the entire DataFrame.

Thus, when the DataFrame is produced by complex operations, counting from saved Parquet output or from cached content is faster.

DataFrame df.cache operation: caching a large dataset can hurt performance, possibly because of disk I/O overhead from spilling. The default storage level for DataFrame cache is MEMORY_AND_DISK, so partitions that do not fit in memory are written to disk.

Licensed under CC BY-NC-SA 4.0