What I have:
country | sources                      | infer_from_source
--------|------------------------------|-----------------------------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"]
"DEU"   | ["DEU"]                      | ["FALSE"]
What I want after applying the function:
country | sources                      | infer_from_source                  | inferred_country
--------|------------------------------|------------------------------------|-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | ["CZE", "FRA"]
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
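I need to create a function: if the `country` column is null, it should extract the countries from the `sources` array according to the boolean values in the `infer_from_source` array (same position in both arrays); otherwise it should return the `country` value.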
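Stated as a plain-Python function over a single row, the rule looks like the sketch below. `infer_country` is a hypothetical name, Python `None` stands in for Spark's null, and the `country` branch wraps the value in a one-element list on the assumption that the output column needs a single array type (the answers below do the same):

```python
from typing import List, Optional


def infer_country(
    country: Optional[str],
    sources: List[str],
    infer_from_source: List[str],
) -> List[str]:
    """Return [country] if it is set, otherwise every source whose flag is "TRUE"."""
    if country is not None:
        # Wrap the scalar so both branches return the same type (a list).
        return [country]
    # Pair each source with the flag at the same position and keep the flagged ones.
    return [src for src, flag in zip(sources, infer_from_source) if flag == "TRUE"]


rows = [
    (None, ["LUX", "CZE", "CHN", "FRA"], ["FALSE", "TRUE", "FALSE", "TRUE"]),
    ("DEU", ["DEU"], ["FALSE"]),
]
for row in rows:
    print(infer_country(*row))
# ['CZE', 'FRA']
# ['DEU']
```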
I wrote this function:
```python
from pyspark.sql.functions import udf

@udf  # no returnType given, so the result is a StringType column
def determine_entity_country(country, sources, infer_from_source):
    if country:
        return country
    else:
        if "TRUE" in infer_from_source:
            # list.index() returns only the FIRST position that matches
            idx = infer_from_source.index("TRUE")
            return sources[idx]
    return None
```
But this produces the following, because `.index("TRUE")` returns only the index of the first element that matches its argument:
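The difference is easy to see in plain Python: `list.index` stops at the first match, whereas iterating with `enumerate` visits every position.

```python
flags = ["FALSE", "TRUE", "FALSE", "TRUE"]
sources = ["LUX", "CZE", "CHN", "FRA"]

# list.index() returns the position of the FIRST match only.
first = flags.index("TRUE")
print(first, sources[first])  # 1 CZE

# Collect every position whose flag is "TRUE" instead.
all_idx = [i for i, flag in enumerate(flags) if flag == "TRUE"]
print(all_idx, [sources[i] for i in all_idx])  # [1, 3] ['CZE', 'FRA']
```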
country | sources                      | infer_from_source                  | inferred_country
--------|------------------------------|------------------------------------|-----------------
null    | ["LUX", "CZE", "CHN", "FRA"] | ["FALSE", "TRUE", "FALSE", "TRUE"] | "CZE"
"DEU"   | ["DEU"]                      | ["FALSE"]                          | "DEU"
Answer:
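When the same logic can be expressed with Spark's built-in functions, avoid UDFs, and Python (PySpark) UDFs in particular: every row has to be serialized out to a Python worker process and back, and the UDF is a black box to Spark's query optimizer.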
Here is an alternative that uses the higher-order array functions `transform` and `filter`:
```python
import pyspark.sql.functions as F

# df: the input DataFrame with columns country, sources, infer_from_source
df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))  # wrap the scalar so both branches return an array
    ).otherwise(
        # 1) transform: keep x where the flag at the same index casts to true, else null
        # 2) filter: drop the null placeholders
        F.expr("""filter(
            transform(sources, (x, i) -> IF(boolean(infer_from_source[i]), x, null)),
            x -> x is not null
        )""")
    )
)

df1.show()
#+-------+--------------------+--------------------+----------------+
#|country|             sources|   infer_from_source|inferred_country|
#+-------+--------------------+--------------------+----------------+
#|   null|[LUX, CZE, CHN, FRA]|[FALSE, TRUE, FAL...|      [CZE, FRA]|
#|    DEU|               [DEU]|             [FALSE]|           [DEU]|
#+-------+--------------------+--------------------+----------------+
```
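To make the two-step expression concrete, here is a rough plain-Python model of what it computes for one row; it is an illustration, not Spark code. `to_bool` and `at` are hypothetical helpers: `to_bool` approximates Spark's non-ANSI string-to-boolean cast (which accepts strings such as "true"/"yes"/"1" and "false"/"no"/"0" case-insensitively and yields null otherwise), and `at` models `arr[i]` returning null for an out-of-range index when ANSI mode is off.

```python
from typing import List, Optional

# Approximation of the strings Spark's non-ANSI cast accepts (case-insensitive).
_TRUE = {"t", "true", "y", "yes", "1"}
_FALSE = {"f", "false", "n", "no", "0"}


def to_bool(s: Optional[str]) -> Optional[bool]:
    """Approximate CAST(string AS BOOLEAN) with ANSI off: unknown -> None."""
    if s is None:
        return None
    s = s.lower()
    if s in _TRUE:
        return True
    if s in _FALSE:
        return False
    return None


def at(arr: List[str], i: int) -> Optional[str]:
    """Model arr[i] with ANSI mode off: out-of-range index -> None."""
    return arr[i] if 0 <= i < len(arr) else None


def transform_then_filter(sources: List[Optional[str]], flags: List[str]) -> List[str]:
    # transform(sources, (x, i) -> IF(boolean(flags[i]), x, null))
    # IF with a null condition takes the else branch, so unknown flags give None too.
    mapped = [x if to_bool(at(flags, i)) else None for i, x in enumerate(sources)]
    # filter(..., x -> x is not null)
    return [x for x in mapped if x is not None]


print(transform_then_filter(["LUX", "CZE", "CHN", "FRA"], ["FALSE", "TRUE", "FALSE", "TRUE"]))
# ['CZE', 'FRA']
```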
Since Spark 3.0, the `filter` lambda can also take the element index, so the `transform` step is no longer needed:
```python
df1 = df.withColumn(
    "inferred_country",
    F.when(
        F.col("country").isNotNull(),
        F.array(F.col("country"))
    ).otherwise(
        # keep each element whose flag at the same index casts to true
        F.expr("filter(sources, (x, i) -> boolean(infer_from_source[i]))")
    )
)
```
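The single-pass version evaluates the predicate once per element and keeps `x` directly, with no null placeholder. A rough plain-Python model of it follows; `filter_with_index` is a hypothetical name, and it handles only "true"/"false" flags rather than every string Spark's cast accepts.

```python
from typing import List, Optional


def filter_with_index(
    sources: List[Optional[str]], flags: List[str]
) -> List[Optional[str]]:
    """Rough model of filter(sources, (x, i) -> boolean(flags[i])) for one row.

    Simplification: only "true"/"false" strings (any case) are modelled. A
    missing flag (flags shorter than sources) acts like null, and an element
    whose predicate is null is dropped, as Spark's filter does.
    """
    kept: List[Optional[str]] = []
    for i, x in enumerate(sources):
        flag = flags[i] if i < len(flags) else None
        if flag is not None and flag.lower() == "true":
            kept.append(x)  # x itself is kept even if it is None
    return kept


print(filter_with_index(["LUX", "CZE", "CHN", "FRA"], ["FALSE", "TRUE", "FALSE", "TRUE"]))
# ['CZE', 'FRA']
```

Because the predicate tests the flag directly, a null entry in `sources` whose flag is TRUE is kept here, whereas the `transform`/`filter` pair uses null as its drop marker and would discard it.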
Answer:
Fixed it! It was just a matter of using a list comprehension to collect every matching index:
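```python
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, StringType

# Declare the array return type: a bare @udf defaults to StringType and
# would turn the returned list into its string form.
@F.udf(returnType=ArrayType(StringType()))
def determine_entity_country(country, sources, infer_from_source):
    if country:
        return [country]  # wrap so every row has the same (array) type
    else:
        if "TRUE" in infer_from_source:
            max_ix = len(infer_from_source)
            # Collect every index whose flag is "TRUE", not just the first
            true_index_array = [x for x in range(0, max_ix) if infer_from_source[x] == "TRUE"]
            return [sources[ix] for ix in true_index_array]
    return None
```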
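If the UDF route is kept, the index bookkeeping can be collapsed with `itertools.compress`, which yields each item of its first argument whose matching selector is truthy and stops at the shorter input. This is a sketch of the function body only, under a hypothetical name `select_flagged`; it would still need wrapping with `F.udf(..., ArrayType(StringType()))` to be used on a DataFrame, and it matches the UDF above for rows where both arrays have the same length.

```python
from itertools import compress
from typing import List, Optional


def select_flagged(
    country: Optional[str], sources: List[str], infer_from_source: List[str]
) -> Optional[List[str]]:
    if country:
        return [country]
    # compress() yields sources[i] wherever the i-th selector is truthy.
    picked = list(compress(sources, (flag == "TRUE" for flag in infer_from_source)))
    # Mirror the UDF: None when no flag is "TRUE".
    return picked or None


print(select_flagged(None, ["LUX", "CZE", "CHN", "FRA"], ["FALSE", "TRUE", "FALSE", "TRUE"]))
# ['CZE', 'FRA']
```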
Source: https://www.uj5u.com/shujuku/360766.html
Tags: arrays, apache-spark, spark, apache-spark-sql
