🚀 环境初始化
这是每个 PySpark 脚本的第一步,启动"指挥部"。
Python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window
# 初始化 SparkSession
# .master("local[*]") 表示使用单机所有 CPU 核心
# .appName(...) 给任务起个名字,方便在 Spark UI 监控
spark = SparkSession.builder \
.master("local[*]") \
.appName("MyBigDataProject") \
.config("spark.sql.shuffle.partitions", "10") \
.getOrCreate()
print(f"✅ Spark Version: {spark.version}")
🌊 标准 ETL 工作流
大数据开发的核心三部曲:E (Extract 读) → T (Transform 算) → L (Load 存)
📥 第一步:读取数据 (Read)
Python
# 1. 读取 CSV (开发/测试常用)
# header=True: 第一行作为列名
# inferSchema=True: 自动推断类型 (耗时,生产环境慎用)
df_csv = spark.read.csv("path/to/data.csv", header=True, inferSchema=True)
# 2. 读取 Parquet (生产环境标准)
# 速度快,体积小,自带 Schema,无需 inferSchema
df_parquet = spark.read.parquet("path/to/data_folder/")
# 3. 读取 JSON
df_json = spark.read.json("path/to/logs.json")
⚙️ 第二步:数据处理 (Transform)
方式 A:SQL 风格 (推荐用于复杂逻辑/聚合)
Python
# 【关键】注册临时视图,打通 Python 与 SQL
df_csv.createOrReplaceTempView("t_source")
result_sql = spark.sql("""
SELECT
City,
Category,
SUM(Amount) as Total_Revenue,
ROUND(AVG(Amount), 2) as Avg_Price
FROM t_source
WHERE Amount > 0
GROUP BY City, Category
HAVING Total_Revenue > 1000
ORDER BY Total_Revenue DESC
""")
方式 B:DSL 风格 (推荐用于列操作/清洗)
Python
# 链式调用:增加列、过滤、改名
df_cleaned = df_csv \
.withColumn("is_high_value", F.when(F.col("Amount") > 500, 1).otherwise(0)) \
.filter(F.col("City").isNotNull()) \
.withColumnRenamed("Amount", "Price") \
.drop("unwanted_column")
📤 第三步:数据落地 (Write)
Python
# 1. 标准落地:Parquet 格式 + 覆盖模式
# mode: "overwrite" (覆盖), "append" (追加), "ignore" (忽略)
result_sql.write.mode("overwrite").parquet("output/report_v1.parquet")
# 2. 【高阶】分区存储 (Partition By)
# 极大提升后续查询速度,必须掌握!
# 结果会生成文件夹结构:output/dt=2024-01-01/city=Shanghai/...
result_sql.write.mode("overwrite") \
.partitionBy("City") \
.parquet("output/partitioned_report")
# 3. 存为 CSV (仅用于导出给 Excel 查看,大文件慎用)
# coalesce(1): 强制合并为一个文件 (避免生成几千个小文件)
result_sql.coalesce(1).write.csv("output/report.csv", header=True)
🧩 高频"胶水"代码速查
日常开发中使用频率最高的微操作。
| 场景 | 代码 | 说明 |
|---|---|---|
| 引用列 | F.col("price") |
比字符串更安全,支持 F.col("a") + F.col("b") |
| 添加常量列 | F.lit(100) |
不能直接写数字,必须用 lit 包裹 |
| 条件判断 | F.when(条件, 值).otherwise(值) |
类似 Excel 的 IF 函数 |
| 空值填充 | df.fillna(0, subset=["col1"]) |
将 col1 的 null 填为 0 |
| 丢弃空值行 | df.dropna() |
只要有 null 就删掉整行 |
| 去重 | df.dropDuplicates(["id"]) |
按 id 列去重 |
| 类型转换 | .cast("long") / .cast("string") |
F.col("age").cast("int") |
| 日期计算 | F.date_add(F.col("dt"), -1) |
日期减一天 |
| 字符串截取 | F.substring("col", 1, 3) |
取前 3 个字符 |
⚔️ 高阶武器:窗口函数
面试必考题:分组取 Top N、计算排名、计算移动平均。
Python
# 场景:找出每个城市(City)中,销售额(Amount)最高的 3 个订单
from pyspark.sql.window import Window
# 1. 定义窗口:按 City 分组,按 Amount 降序排
window_spec = Window.partitionBy("City").orderBy(F.col("Amount").desc())
# 2. 计算排名
# row_number(): 1, 2, 3, 4 (不并列)
# dense_rank(): 1, 2, 2, 3 (并列不跳号)
df_ranked = df_csv.withColumn("rank", F.row_number().over(window_spec))
# 3. 过滤 Top 3
top3_df = df_ranked.filter(F.col("rank") <= 3)
⚡ 性能调优
当代码跑得慢时,检查这三点。
Python
# 1. 缓存 (Cache)
# 如果一个 DataFrame 后面被用了 2 次以上,一定要 Cache!
heavy_df = spark.read.parquet("big_data")
heavy_df.cache()
heavy_df.count() # 触发立即缓存
# 2. 重新分区 (Repartition)
# 解决"小文件过多"或"数据倾斜"问题
# 写入前将数据合并为 10 个文件
final_df.repartition(10).write.parquet("output")
# 3. 查看执行计划 (Explain)
# 看看 Spark 到底是怎么执行 Join 的
df.explain()
📜 万能模版脚本
直接复制这个模版,填入你的逻辑即可开始工作。
Python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
def main():
# 1. Init
spark = SparkSession.builder.appName("TemplateJob").getOrCreate()
# 2. Read
df_raw = spark.read.csv("input_data.csv", header=True, inferSchema=True)
# 3. Transform
# 3.1 清洗
df_clean = df_raw.filter(F.col("id").isNotNull()).dropDuplicates()
# 3.2 核心逻辑 (SQL)
df_clean.createOrReplaceTempView("source")
df_result = spark.sql("""
SELECT category, COUNT(*) as cnt
FROM source
GROUP BY category
""")
# 4. Write
df_result.write.mode("overwrite").parquet("output_result")
print("✅ Job Finished!")
if __name__ == "__main__":
main()