PySparkでgroupbyで集計したデータを配列にして一行にまとめる

重複をそのまま配列にする

+---+----+
| id|type|
+---+----+
|  1|   a|
|  1|   b|
|  1|   b|
|  2|   c|
|  2|   c|
|  2|   c|
+---+----+

このデータに対してidで集計して集計されたデータを一行にまとめてみます。

from pyspark.sql.functions import collect_list

df = spark.createDataFrame([
        (1, "a"),
        (1, "b"),
        (1, "b"),
        (2, "c"),
        (2, "c"),
        (2, "c")
    ],
    ["id", "type"])

df = df.groupby("id").agg(collect_list("type").alias("types"))
df.show()

.agg(collect_list("type"))とすることで指定の列の集計を配列にしてまとめてくれます。

+---+---------+
| id|    types|
+---+---------+
|  1|[a, b, b]|
|  2|[c, c, c]|
+---+---------+

重複を削除して配列にする

collect_setメソッドを使用すると

from pyspark.sql.functions import collect_set

df = spark.createDataFrame([
        (1, "a"),
        (1, "b"),
        (1, "b"),
        (2, "c"),
        (2, "c"),
        (2, "c")
    ],
    ["id", "type"])

df = df.groupby("id").agg(collect_set("type").alias("types"))
df.show()

この様に重複を削除した上で配列にしてくれます。

+---+------+
| id| types|
+---+------+
|  1|[b, a]|
|  2|   [c]|
+---+------+

参照