Modern Data Pipelines with PySpark: Best Practices & Patterns
Build production-grade data pipelines with PySpark. Learn optimization techniques, design patterns, and best practices from processing petabytes of data at scale.
After four years of running PySpark pipelines on petabyte-scale datasets, I've learned that the gap between code that works in development and code that runs reliably in production is large. This guide covers the patterns that close that gap.
Session Configuration: Start Right
The SparkSession configuration determines your job's ceiling before you write a single transformation. These settings matter most:
```python
from pyspark.sql import SparkSession

# Note: spark.sql.shuffle.partitions must be a number in open-source Spark
# ("auto" is Databricks-only); AQE coalesces it downward at runtime anyway.
spark = SparkSession.builder \
    .appName("production-etl-job") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.sql.parquet.filterPushdown", "true") \
    .config("spark.sql.parquet.mergeSchema", "false") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.dynamicAllocation.minExecutors", "2") \
    .config("spark.dynamicAllocation.maxExecutors", "20") \
    .getOrCreate()
```
Adaptive Query Execution (AQE): Enabled by default since Spark 3.2, but worth confirming. AQE dynamically adjusts join strategies, coalesces shuffle partitions after they're computed, and handles skew joins automatically. It is the single highest-leverage configuration change for most jobs.
mergeSchema = false: Merging Parquet schemas across files forces Spark to read every file's footer, which is expensive, and it masks upstream schema drift. Enforce schemas explicitly instead.
Partition Management
Partition count is the most consequential tuning decision in a Spark job. The wrong count causes:
- Too few: tasks run out of memory, GC pressure, slow execution
- Too many: scheduling overhead, small files, slow writes

Rule of thumb: target 128MB–256MB of data per partition.
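The rule of thumb reduces to simple arithmetic. A quick sketch in plain Python (the 200MB target and the helper name are illustrative):

```python
import math

def partitions_for(total_bytes: int, target_mb: int = 200) -> int:
    """Partition count targeting roughly target_mb of data per partition."""
    target_bytes = target_mb * 1024 * 1024
    return max(1, math.ceil(total_bytes / target_bytes))

# A 50 GB dataset at ~200 MB per partition lands at 256 partitions
print(partitions_for(50 * 1024**3))  # 256
```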
```python
def optimal_partition_count(df, target_mb=200):
    """Estimate optimal partition count from Catalyst's size estimate."""
    # Use Spark's internal size estimator (py4j access to the JVM plan;
    # this is an internal API, so pin your Spark version if you rely on it)
    estimated_bytes = spark._jsparkSession \
        .sessionState() \
        .executePlan(df._jdf.queryExecution().analyzed()) \
        .optimizedPlan() \
        .stats() \
        .sizeInBytes()
    target_bytes = target_mb * 1024 * 1024
    return max(1, int(estimated_bytes / target_bytes))

# Repartition before heavy transformations
df = df.repartition(optimal_partition_count(df))
```
Coalesce vs. repartition:

- repartition(n): full shuffle; use when increasing the partition count or changing the partition key
- coalesce(n): no shuffle, only merges partitions; use only when decreasing the partition count

Before writing, coalesce to reduce the output file count:

```python
df.coalesce(10).write.format("delta").mode("append").save(output_path)
```
Join Optimization
Joins are the most common source of performance problems in PySpark.
Broadcast Joins
When one DataFrame is small enough to fit in executor memory, broadcast it to avoid a shuffle join:
```python
from pyspark.sql.functions import broadcast

# Explicit broadcast hint — Spark will also do this automatically
# if the table is below spark.sql.autoBroadcastJoinThreshold (default: 10MB)
result = large_df.join(broadcast(small_lookup_df), "product_id")
```
Set the threshold appropriately for your executor memory:

```python
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50mb")
```
Handling Skewed Joins
Data skew — one partition containing disproportionately more data — causes one task to run 10× longer than all others.
Diagnosis:

```python
from pyspark.sql.functions import spark_partition_id, desc

# Check partition sizes after a shuffle
df.groupBy(spark_partition_id()).count().orderBy(desc("count")).show(20)
```
Solution — salting:
```python
from pyspark.sql.functions import array, col, concat, explode, floor, lit, rand

SALT_FACTOR = 10

# Add a random salt to the skewed key on the large table
salted_large = large_df.withColumn(
    "salted_key",
    concat(col("user_id"), lit("_"), floor(rand() * SALT_FACTOR).cast("string"))
)

# Explode the small table so every salt value has a matching row
salted_small = small_df.withColumn(
    "salt_range",
    array([lit(str(i)) for i in range(SALT_FACTOR)])
).withColumn("salt", explode("salt_range")) \
    .withColumn("salted_key", concat(col("user_id"), lit("_"), col("salt"))) \
    .drop("salt_range", "salt")

result = salted_large.join(salted_small, "salted_key").drop("salted_key")
```
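On Spark 3.x, manual salting is often unnecessary: the AQE skew-join handling enabled in the session config above splits oversized partitions automatically. The thresholds are tunable; the values below are the documented defaults, shown for illustration rather than as recommendations:

```python
# A partition is treated as skewed when it exceeds BOTH the factor times
# the median partition size AND the absolute byte threshold
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
```

Salting remains useful where AQE cannot help, e.g. skewed aggregations rather than joins.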
Memory Management
Spark memory is divided into execution memory (shuffles, sorts, joins) and storage memory (caching). OOM errors usually mean one is starving the other.
Adjust the memory fractions if you cache large DataFrames. Note these are static configs: set them when the session is built, not at runtime with spark.conf.set.

```python
# spark.memory.fraction: share of (heap - 300MB) for execution + storage
# spark.memory.storageFraction: portion of that protected for caching
spark = SparkSession.builder \
    .config("spark.memory.fraction", "0.75") \
    .config("spark.memory.storageFraction", "0.3") \
    .getOrCreate()
```
When to cache:

```python
from pyspark import StorageLevel

# Cache only when a DataFrame is used multiple times in the same job.
# MEMORY_AND_DISK spills to disk instead of failing if the cache doesn't fit.
df_lookup.persist(StorageLevel.MEMORY_AND_DISK)

# ... multiple operations with df_lookup ...

# Unpersist when done — don't leak memory
df_lookup.unpersist()
```

Do not cache DataFrames that are used only once: caching has a write cost and wastes memory.
Schema Management
Explicit schema definition is mandatory in production. Relying on inference is slow and breaks silently when upstream data changes.
```python
from pyspark.sql.types import (
    StructType, StructField, StringType,
    TimestampType, DecimalType, BooleanType
)

ORDER_SCHEMA = StructType([
    StructField("order_id", StringType(), nullable=False),
    StructField("user_id", StringType(), nullable=False),
    StructField("product_id", StringType(), nullable=True),
    StructField("amount", DecimalType(18, 2), nullable=False),
    StructField("created_at", TimestampType(), nullable=False),
    StructField("is_test_order", BooleanType(), nullable=False),
    StructField("status", StringType(), nullable=True),
])

df = spark.read.schema(ORDER_SCHEMA).json("s3://raw/orders/")

# Fail fast: check names and types match expectation.
# (Compare per-field rather than whole-schema equality: file sources force
# every field to nullable=True on read, so df.schema == ORDER_SCHEMA would
# fail on the non-nullable fields.)
expected = [(f.name, f.dataType) for f in ORDER_SCHEMA]
actual = [(f.name, f.dataType) for f in df.schema]
assert actual == expected, f"Schema mismatch: {df.schema}"
```
Writing Idiomatic PySpark
Common anti-patterns that hurt performance and readability:

Anti-pattern: UDFs for simple transformations
```python
# Bad: a Python UDF breaks Catalyst optimization and adds serialization overhead
from pyspark.sql.functions import udf

@udf("string")
def format_name(first, last):
    return f"{first} {last}"

# Good: use built-in functions — they stay in the JVM and Catalyst can optimize
from pyspark.sql.functions import col, concat_ws

df = df.withColumn("full_name", concat_ws(" ", col("first_name"), col("last_name")))
```
Anti-pattern: Iterating rows with .collect()
```python
# Bad: brings all data to the driver, defeats distributed processing
for row in df.collect():
    process(row)

# Good: push processing to the executors
def handle_partition(rows):
    for row in rows:
        process(row)

df.foreachPartition(handle_partition)
```
Anti-pattern: Chaining .withColumn() for many columns
```python
# Bad: each withColumn creates a new DataFrame plan node — slow with many columns
df = df.withColumn("col1", ...).withColumn("col2", ...).withColumn("col3", ...)

# Good: use select with multiple expressions
from pyspark.sql.functions import expr

df = df.select(
    "*",
    expr("...").alias("col1"),
    expr("...").alias("col2"),
    expr("...").alias("col3"),
)
```
Production Observability
Instrument your jobs before they reach production:

```python
import logging
from datetime import datetime

from pyspark.sql.functions import col

logger = logging.getLogger(__name__)

def process_partition(input_path: str, output_path: str, processing_date: str):
    start = datetime.utcnow()
    df = spark.read.format("delta").load(input_path) \
        .filter(col("processing_date") == processing_date)
    input_count = df.count()
    logger.info(f"Input records: {input_count:,}")

    # --- transformations ---
    result = transform(df)

    # Validate before write
    output_count = result.count()
    if output_count == 0:
        raise ValueError(f"Output is empty for {processing_date} — aborting write")
    drop_rate = 1 - (output_count / input_count)
    if drop_rate > 0.05:
        raise ValueError(f"Drop rate {drop_rate:.1%} exceeds 5% threshold")

    result.write.format("delta").mode("overwrite") \
        .option("replaceWhere", f"processing_date = '{processing_date}'") \
        .save(output_path)

    duration = (datetime.utcnow() - start).total_seconds()
    logger.info(f"Completed: {output_count:,} records in {duration:.1f}s")
```
Pre-write validation — asserting that output is non-empty and that the drop rate is within expected bounds — prevents silent data loss. This check takes seconds and has caught real bugs more times than I can count.
The discipline of PySpark at scale is less about advanced API knowledge and more about operational hygiene: explicit schemas, measured partitioning, appropriate caching, and validating before committing writes.