PySpark DataFrame 添加自增 ID


PySpark DataFrame 添加自增 ID

本文原始地址:https://sitoi.cn/posts/62634.html

在用 Spark 處理數據的時候,經常需要給全量數據增加一列自增 ID 序號,在存入數據庫的時候,自增 ID 也常常是一個很關鍵的要素。
在 DataFrame 的 API 中沒有實現這一功能,所以只能通過其他方式實現,或者轉成 RDD 再用 RDD 的 zipWithIndex 算子實現。
下面呢就介紹三種實現方式。

創建 DataFrame 對象

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        {"name": "Alice", "age": 18},
        {"name": "Sitoi", "age": 22},
        {"name": "Shitao", "age": 22},
        {"name": "Tom", "age": 7},
        {"name": "De", "age": 17},
        {"name": "Apple", "age": 45}
    ]
)
df.show()

輸出:

+---+------+
|age|  name|
+---+------+
| 18| Alice|
| 22| Sitoi|
| 22|Shitao|
|  7|   Tom|
| 17|    De|
| 45| Apple|
+---+------+

方式一:monotonically_increasing_id()

使用自帶函數 monotonically_increasing_id() 創建,由於 spark 會有分區,所以生成的 ID 保證單調增加且唯一,但不是連續的

優點:對於沒有分區的文件,處理速度快。
缺點:由於 spark 的分區,會導致,ID 不是連續增加。

df = df.withColumn("id", monotonically_increasing_id())
df.show()

輸出:

+---+------+-----------+
|age|  name|         id|
+---+------+-----------+
| 18| Alice| 8589934592|
| 22| Sitoi|17179869184|
| 22|Shitao|25769803776|
|  7|   Tom|42949672960|
| 17|    De|51539607552|
| 45| Apple|60129542144|
+---+------+-----------+

如果讀取本地的單個 CSV 文件 或 JSON 文件,ID 會是連續增加且唯一的。

方法二:窗口函數

利用窗口函數:設置窗口函數的分區以及排序,因為是全局排序而不是分組排序,所有分區依據為空,排序規則沒有特殊要求也可以隨意填寫

優點:保證 ID 連續增加且唯一
缺點:運行速度滿,並且數據量過大會爆內存,需要排序,會改變原始數據順序。

from pyspark.sql.functions import row_number

spec = Window.partitionBy().orderBy("age")
df = df.withColumn("id", row_number().over(spec))
df.show()

輸出:

+---+------+---+
|age|  name| id|
+---+------+---+
|  7|   Tom|  1|
| 17|    De|  2|
| 18| Alice|  3|
| 22| Sitoi|  4|
| 22|Shitao|  5|
| 45| Apple|  6|
+---+------+---+

方法三:RDD 的 zipWithIndex 算子

轉成 RDD 再用 RDD 的 zipWithIndex 算子實現

優點:保證 ID 連續 增加且唯一。
缺點:運行速度慢。

from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.types import StructField, LongType

spark = SparkSession.builder.getOrCreate()

schema = df.schema.add(StructField("id", LongType()))
rdd = df.rdd.zipWithIndex()


def flat(l):
    for k in l:
        if not isinstance(k, (list, tuple)):
            yield k
        else:
            yield from flat(k)


rdd = rdd.map(lambda x: list(flat(x)))
df = spark.createDataFrame(rdd, schema)
df.show()

輸出:

+---+------+---+
|age|  name| id|
+---+------+---+
| 18| Alice|  0|
| 22| Sitoi|  1|
| 22|Shitao|  2|
|  7|   Tom|  3|
| 17|    De|  4|
| 45| Apple|  5|
+---+------+---+


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM