🚀 环境初始化

这是每个 PySpark 脚本的第一步,启动"指挥部"。

Python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window

# 初始化 SparkSession
# .master("local[*]") 表示使用单机所有 CPU 核心
# .appName(...) 给任务起个名字,方便在 Spark UI 监控
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("MyBigDataProject") \
    .config("spark.sql.shuffle.partitions", "10") \
    .getOrCreate()

print(f"✅ Spark Version: {spark.version}")

🌊 标准 ETL 工作流

大数据开发的核心三部曲:E (Extract 读) → T (Transform 算) → L (Load 存)

📥 第一步:读取数据 (Read)

Python
# 1. 读取 CSV (开发/测试常用)
# header=True: 第一行作为列名
# inferSchema=True: 自动推断类型 (耗时,生产环境慎用)
df_csv = spark.read.csv("path/to/data.csv", header=True, inferSchema=True)

# 2. 读取 Parquet (生产环境标准)
# 速度快,体积小,自带 Schema,无需 inferSchema
df_parquet = spark.read.parquet("path/to/data_folder/")

# 3. 读取 JSON
df_json = spark.read.json("path/to/logs.json")

⚙️ 第二步:数据处理 (Transform)

方式 A:SQL 风格 (推荐用于复杂逻辑/聚合)

Python
# 【关键】注册临时视图,打通 Python 与 SQL
df_csv.createOrReplaceTempView("t_source")

result_sql = spark.sql("""
    SELECT
        City,
        Category,
        SUM(Amount) as Total_Revenue,
        ROUND(AVG(Amount), 2) as Avg_Price
    FROM t_source
    WHERE Amount > 0
    GROUP BY City, Category
    HAVING Total_Revenue > 1000
    ORDER BY Total_Revenue DESC
""")

方式 B:DSL 风格 (推荐用于列操作/清洗)

Python
# 链式调用:增加列、过滤、改名
df_cleaned = df_csv \
    .withColumn("is_high_value", F.when(F.col("Amount") > 500, 1).otherwise(0)) \
    .filter(F.col("City").isNotNull()) \
    .withColumnRenamed("Amount", "Price") \
    .drop("unwanted_column")

📤 第三步:数据落地 (Write)

Python
# 1. 标准落地:Parquet 格式 + 覆盖模式
# mode: "overwrite" (覆盖), "append" (追加), "ignore" (忽略)
result_sql.write.mode("overwrite").parquet("output/report_v1.parquet")

# 2. 【高阶】分区存储 (Partition By)
# 极大提升后续查询速度,必须掌握!
# 结果会生成文件夹结构:output/dt=2024-01-01/city=Shanghai/...
result_sql.write.mode("overwrite") \
    .partitionBy("City") \
    .parquet("output/partitioned_report")

# 3. 存为 CSV (仅用于导出给 Excel 查看,大文件慎用)
# coalesce(1): 强制合并为一个文件 (避免生成几千个小文件)
result_sql.coalesce(1).write.csv("output/report.csv", header=True)

🧩 高频"胶水"代码速查

日常开发中使用频率最高的微操作。

场景 代码 说明
引用列 F.col("price") 比字符串更安全,支持 F.col("a") + F.col("b")
添加常量列 F.lit(100) 不能直接写数字,必须用 lit 包裹
条件判断 F.when(条件, 值).otherwise(值) 类似 Excel 的 IF 函数
空值填充 df.fillna(0, subset=["col1"]) 将 col1 的 null 填为 0
丢弃空值行 df.dropna() 只要有 null 就删掉整行
去重 df.dropDuplicates(["id"]) 按 id 列去重
类型转换 .cast("long") / .cast("string") F.col("age").cast("int")
日期计算 F.date_add(F.col("dt"), -1) 日期减一天
字符串截取 F.substring("col", 1, 3) 取前 3 个字符

⚔️ 高阶武器:窗口函数

面试必考题:分组取 Top N、计算排名、计算移动平均。

Python
# 场景:找出每个城市(City)中,销售额(Amount)最高的 3 个订单
from pyspark.sql.window import Window

# 1. 定义窗口:按 City 分组,按 Amount 降序排
window_spec = Window.partitionBy("City").orderBy(F.col("Amount").desc())

# 2. 计算排名
# row_number(): 1, 2, 3, 4 (不并列)
# dense_rank(): 1, 2, 2, 3 (并列不跳号)
df_ranked = df_csv.withColumn("rank", F.row_number().over(window_spec))

# 3. 过滤 Top 3
top3_df = df_ranked.filter(F.col("rank") <= 3)

⚡ 性能调优

当代码跑得慢时,检查这三点。

Python
# 1. 缓存 (Cache)
# 如果一个 DataFrame 后面被用了 2 次以上,一定要 Cache!
heavy_df = spark.read.parquet("big_data")
heavy_df.cache()
heavy_df.count() # 触发立即缓存

# 2. 重新分区 (Repartition)
# 解决"小文件过多"或"数据倾斜"问题
# 写入前将数据合并为 10 个文件
final_df.repartition(10).write.parquet("output")

# 3. 查看执行计划 (Explain)
# 看看 Spark 到底是怎么执行 Join 的
df.explain()

📜 万能模版脚本

直接复制这个模版,填入你的逻辑即可开始工作。

Python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

def main():
    # 1. Init
    spark = SparkSession.builder.appName("TemplateJob").getOrCreate()

    # 2. Read
    df_raw = spark.read.csv("input_data.csv", header=True, inferSchema=True)

    # 3. Transform
    # 3.1 清洗
    df_clean = df_raw.filter(F.col("id").isNotNull()).dropDuplicates()

    # 3.2 核心逻辑 (SQL)
    df_clean.createOrReplaceTempView("source")
    df_result = spark.sql("""
        SELECT category, COUNT(*) as cnt
        FROM source
        GROUP BY category
    """)

    # 4. Write
    df_result.write.mode("overwrite").parquet("output_result")

    print("✅ Job Finished!")

if __name__ == "__main__":
    main()