PySparkで特定の列の出現回数をカウントしてmap型でまとめる

+---+----+
| id|type|
+---+----+
|  1|   a|
|  1|   b|
|  1|   b|
|  2|   c|
|  2|   c|
|  2|   c|
+---+----+

のデータを元にこの様に変換してみます(map型の場合parquet形式で保存が出来ないので他の形式で保存する必要があります)

+---+----------------+
| id|           count|
+---+----------------+
|  1|[a -> 1, b -> 2]|
|  2|        [c -> 3]|
+---+----------------+

コード

from pyspark.sql.functions import collect_list, map_from_entries, struct

df = spark.createDataFrame([
        (1, "a"),
        (1, "b"),
        (1, "b"),
        (2, "c"),
        (2, "c"),
        (2, "c")
    ],
    ["id", "type"])

df = df.groupby("id", "type").count()
df = df.groupby("id").agg(map_from_entries(collect_list(struct("type","count"))).alias("count"))

処理過程

データ

処理前のデータです。

+---+----+
| id|type|
+---+----+
|  1|   a|
|  1|   b|
|  1|   b|
|  2|   c|
|  2|   c|
|  2|   c|
+---+----+

出現回数のカウント

df.groupby("id", "type").count()

まず集計関数を使って出現回数をカウントしたDataFrameを作成します。

+---+----+-----+
| id|type|count|
+---+----+-----+
|  2|   c|    3|
|  1|   a|    1|
|  1|   b|    2|
+---+----+-----+

一度2次元配列にする

カウントしたデータを集計関数を使ってハッシュ形式に一気に変換するのですが、飛びすぎて分かりにくいので一旦この様に

df.groupby("id").agg(collect_list(struct("type","count")))

途中結果を見てみます。

+---+---------------------------------+
| id|collect_list(struct(type, count))|
+---+---------------------------------+
|  1|                 [[a, 1], [b, 2]]|
|  2|                         [[c, 3]]|
+---+---------------------------------+

2次元配列に変換します。

二次元配列をmapに変換する

df.groupby("id").agg(map_from_entries(collect_list(struct("type","count"))).alias("count"))

二次元配列をmap_from_entriesを使ってmapに変換してcount列に上書きします。

+---+----------------+
| id|           count|
+---+----------------+
|  1|[a -> 1, b -> 2]|
|  2|        [c -> 3]|
+---+----------------+

JSON形式で出力

to_jsonをimportした上で列内のmapデータをjson形式にしたい場合↓の様にやると変換できます。

df.withColumn("count", to_json("count"))
+---+-------------+
| id|        count|
+---+-------------+
|  1|{"a":1,"b":2}|
|  2|      {"c":3}|
+---+-------------+

参照