我有一個如下所示的 spark 資料框。
| 日期 | ID | 視窗大小 | 數量 |
|---|---|---|---|
| 01/01/2020 | 1 | 2 | 1 |
| 2020 年 2 月 1 日 | 1 | 2 | 2 |
| 03/01/2020 | 1 | 2 | 3 |
| 2020 年 4 月 1 日 | 1 | 2 | 4 |
| 01/01/2020 | 2 | 3 | 1 |
| 2020 年 2 月 1 日 | 2 | 3 | 2 |
| 03/01/2020 | 2 | 3 | 3 |
| 2020 年 4 月 1 日 | 2 | 3 | 4 |
我正在嘗試將大小為 window_size 的滾動視窗應用于資料框中的每個 ID 并獲取滾動總和。基本上我正在計算一個滾動總和(pd.groupby.rolling(window=n).sum()在熊貓中),其中每個組的視窗大小(n)可以改變。
預期產出
| 日期 | ID | 視窗大小 | 數量 | 滾動總和 |
|---|---|---|---|---|
| 01/01/2020 | 1 | 2 | 1 | 空值 |
| 2020 年 2 月 1 日 | 1 | 2 | 2 | 3 |
| 03/01/2020 | 1 | 2 | 3 | 5 |
| 2020 年 4 月 1 日 | 1 | 2 | 4 | 7 |
| 01/01/2020 | 2 | 3 | 1 | 空值 |
| 2020 年 2 月 1 日 | 2 | 3 | 2 | 空值 |
| 03/01/2020 | 2 | 3 | 3 | 6 |
| 2020 年 4 月 1 日 | 2 | 3 | 4 | 9 |
我正在努力尋找一個在大型資料幀( - 350M 行)上有效且足夠快的解決方案。
我試過的
我在下面的執行緒中嘗試了解決方案:
這個想法是首先使用sf.collect_list然后ArrayType正確地對列進行切片。
import pyspark.sql.types as st
import pyspark.sql.function as sf
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
但是,這會產生以下錯誤:
TypeError:列不可迭代
我也嘗試過使用sf.expr如下
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.expr("slice('window_size', 'count', 'window_size')")))
).show()
其中產生:
data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;
I tried manually casting the qty_list column to ArrayType(IntegerType()) with the same result.
I tried using a UDF but that fails with several out of memory errors after 1,5 hours or so.
Questions
Reading the spark documentation suggests to me that I should be able to pass columns to
sf.slice(), am I doing something wrong? Where is theTypeErrorcoming from?Is there a better way to achieve what I want without using
sf.collect_list()and/orsf.slice()?If all else fails, what would be the optimal way to do this using a udf? I attempted different versions of the same udf and tried to make sure the udf is the last operation spark has to perform, but all failed.
uj5u.com熱心網友回復:
關于你得到的錯誤:
- 第一個意味著您不能將列傳遞給
slice使用 DataFrame API 函式(除非您有 Spark 3.1 )。但是當您嘗試在 SQL 運算式中使用它時,您已經得到了它。 - 發生第二個錯誤是因為您傳遞了
expr.slice(qty_list, count, window_size)否則 Spark應該將它們視為字串,因此會出現錯誤訊息。
也就是說,您幾乎明白了,您需要更改切片運算式以獲得正確的陣列大小,然后使用aggregate函式對結果陣列的值求和。試試這個:
from pyspark.sql import Window
import pyspark.sql.functions as F
w = Window.partitionBy('id').orderBy('date')
output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
.withColumn("rn", F.row_number().over(w)) \
.withColumn(
"qty_list",
F.when(
F.col('rn') < F.col('window_size'),
None
).otherwise(F.expr("slice(qty_list, rn-window_size 1, window_size)"))
).withColumn(
"rolling_sum",
F.expr("aggregate(qty_list, 0D, (acc, x) -> acc x)").cast("int")
).drop("qty_list", "rn")
output.show()
# ---------- --- ----------- --- -----------
#| date| ID|window_size|qty|rolling_sum|
# ---------- --- ----------- --- -----------
#|01/01/2020| 1| 2| 1| null|
#|02/01/2020| 1| 2| 2| 3|
#|03/01/2020| 1| 2| 3| 5|
#|04/01/2020| 1| 2| 4| 7|
#|01/01/2020| 2| 3| 1| null|
#|02/01/2020| 2| 3| 2| null|
#|03/01/2020| 2| 3| 3| 6|
#|04/01/2020| 2| 3| 4| 9|
# ---------- --- ----------- --- -----------
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/405144.html
標籤:
