This is the story of join optimization—the techniques that make the difference between a join that completes in 5 minutes and one that times out after 5 hours. Joins are among the most common and most expensive operations in analytical SQL. They require matching rows from two datasets by key, which almost always involves moving data around the cluster. The strategies range from broadcasting small tables (near-zero extra cost) to pre-bucketing data (eliminating shuffles entirely). Understanding these patterns turns join optimization from guesswork into a systematic process.
When Spark has no better option, it uses SortMergeJoin (SMJ):
Two full shuffles, two full sorts. For a join between a 1TB table and a 500GB table, you’re moving 1.5TB of data across the network and writing 1.5TB to shuffle storage. This is correct and scalable, but expensive.
Every join optimization below is fundamentally about reducing or eliminating these two shuffles.
If one side of the join is small (a dimension table, a lookup table, a filtered result set), broadcast it to every executor and let each executor perform a local hash join with its portion of the large side.
from pyspark.sql.functions import broadcast
result = large_orders.join(broadcast(small_countries), "country_code")
Or let Spark auto-broadcast via spark.sql.autoBroadcastJoinThreshold (default 10MB):
spark.sql.autoBroadcastJoinThreshold = 50MB # broadcast tables up to 50MB
Why it’s fast: the large table never shuffles. The small table is collected to the driver (one shuffle, tiny), broadcast to all executors (a tree broadcast), and each executor builds a hash map from it. The large side is read locally and each row is probed in the in-memory hash map. Zero network for the large side.
Broadcast hash join is like printing a pocket reference card for everyone. Instead of sending every customer (large side row) to a central office to look up their country info (small side), you print copies of the country lookup table and hand one to each customer service agent (executor). Each agent answers questions locally without sending anyone anywhere.
When to use:
When NOT to use:
A ShuffledHashJoin shuffles both sides (so matching keys go to the same partition) but then builds an in-memory hash table from the smaller side within each partition, rather than sorting and merging.
This is faster than SortMergeJoin if:
Spark may choose this automatically for certain size ratios, or it can be forced:
spark.sql.join.preferSortMergeJoin = false
ShuffledHashJoin is like organizing a speed-dating event where you shuffle everyone by zip code first. People within the same zip code (partition) then do local matching—one group (smaller side) sits at tables (hash map), and the other group (larger side) walks through. Efficient if the sitting group fits at the tables; chaotic if there are too many people for the available tables (spill).
Bucketing is a write-time optimization that pre-distributes a table’s data by a key, storing it in a fixed number of buckets (files) where all rows with the same hash(key) % num_buckets value are in the same bucket.
orders_df.write.bucketBy(100, "customer_id").sortBy("customer_id") \
.saveAsTable("orders_bucketed")
customers_df.write.bucketBy(100, "customer_id").sortBy("customer_id") \
.saveAsTable("customers_bucketed")
When you join two tables that are bucketed on the same column with the same number of buckets, no shuffle is needed: Spark reads bucket N from the left side and bucket N from the right side on the same executor, and merges them directly. The shuffle is eliminated entirely at query time.
Bucketing is like pre-sorting mail into neighborhood bins at the post office. When mail arrives, instead of sorting it during delivery (shuffle at query time), the post office sorts it into 100 neighborhood bins during intake. When delivery day comes, the deliverer for neighborhood 5 picks up bin 5 from both the incoming letters pile and the incoming packages pile—no sorting needed at delivery time. The work was done once at write time; every future read benefits.
When to use:
customer_id is the universal join key in a star schema)Limitations:
With Adaptive Query Execution (AQE) enabled (spark.sql.adaptive.enabled=true), Spark can convert a SortMergeJoin to a BroadcastHashJoin at runtime, based on actual data sizes observed after the map phase of the shuffle.
This is useful when the query planner estimated one side to be larger than autoBroadcastJoinThreshold (so it planned a SMJ), but after filtering and aggregation, the actual data at the join point is small enough to broadcast.
-- Initially planned as SortMergeJoin (planner didn't know filtered result would be small)
-- AQE detects: "left side after filtering is only 8MB"
-- Converts to BroadcastHashJoin at runtime—no restart needed
AQE join conversion is like a taxi dispatcher who switches you from a shared van to a private car mid-route. You booked a shared van (SMJ) because the dispatcher didn’t know how many other passengers there’d be. Halfway through the journey, they see that only 2 passengers remain and upgrade you to a private car (BHJ). The original plan wasn’t wrong; it just didn’t have enough information yet.
Data skew in a join occurs when a small number of keys have disproportionately many rows. The partition containing the skewed key takes much longer than all other partitions—even with thousands of executor cores, one task holding up the entire stage.
Common in:
Approach 1: AQE skew join handling (Spark 3.0+): automatically detects skewed partitions (tasks that are 5× the median size) and splits them into sub-partitions, distributing the work across multiple tasks. The skewed key’s data is read multiple times (once per sub-partition), but this is correct because the join is an equi-join.
Approach 2: Salting (manual): add a random salt suffix to the skewed key before joining, and explode the small side’s matching keys to include all salt values:
# On the large side (skewed): add random salt 0-9 to the join key
from pyspark.sql.functions import concat_ws, lit, floor, rand
large = large_df.withColumn("salted_key",
concat_ws("_", col("customer_id"), (floor(rand() * 10)).cast("string")))
# On the small side: explode each key to all 10 salt values
from pyspark.sql.functions import explode, array
small = small_df.withColumn("salt", explode(array([lit(i) for i in range(10)])))
small = small.withColumn("salted_key",
concat_ws("_", col("customer_id"), col("salt").cast("string")))
result = large.join(small, "salted_key")
Salting distributes the skewed key’s rows across 10 partitions (salt 0 through 9), each handled by a separate task. The small side is replicated 10× to match, which is acceptable if the small side is small.
Before the join, apply filters that reduce the number of partitions read from both sides. For Parquet/Delta tables, partition column filters eliminate entire partition directories from the scan.
SELECT * FROM orders JOIN customers ON orders.customer_id = customers.id
WHERE orders.order_date >= '2024-01-01'
If orders is partitioned by order_date, this query only reads partitions from 2024 onward—potentially 50% of the data. The join input is smaller; the shuffle is smaller.
With Dynamic Partition Pruning (DPP) enabled (AQE), Spark evaluates the dimension table filter, builds a list of matching keys, and uses that as a runtime filter on the fact table scan—even when the fact table isn’t directly filtered. This is particularly powerful in star schema joins.
Is one side small enough to broadcast?
→ Yes: use Broadcast Hash Join (fastest; zero shuffle on the large side)
→ No: continue
Are both tables frequently joined on the same key?
→ Yes: bucket both tables on the join key → use bucketed join (zero shuffle at query time)
→ No: continue
Is AQE enabled? (it should be, in Spark 3.0+)
→ Yes: AQE may auto-convert to BHJ and handle skew
→ No: enable AQE
Is there skew in the join key?
→ Yes: use AQE skew join handling or manual salting
→ No: tuned SortMergeJoin with appropriate partition count
Can you reduce join input with filters or partition pruning?
→ Always try: filter before joining, use partitioned tables, push filters down
Join optimization is about eliminating or minimizing shuffles and ensuring each task has a manageable amount of data. The hierarchy of strategies from most to least expensive (in terms of shuffle avoidance): Broadcast Hash Join (no shuffle on the large side), bucketed join (no shuffle at query time, shuffle at write time), AQE runtime conversion (automatic BHJ when actual sizes allow), ShuffledHashJoin (both sides shuffled, but no sort), and SortMergeJoin (both sides shuffled and sorted—correct for any size, but most expensive). Skew handling (AQE or salting) is orthogonal: it prevents one task from monopolizing a stage regardless of join strategy. Partition pruning and filter pushdown reduce the input size before the join runs. So the story of join optimization is: profile the join sizes → choose the strategy that eliminates the most shuffles → handle skew → push down filters → verify with EXPLAIN → measure and iterate.