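I am using Spark to process a dataset of 20 million XML files. I was originally processing all of them, but I actually only need about a third. In a separate Spark workflow, I created a dataframe keyfilter in which one column is the key of each XML and the second column is a boolean: True if the XML corresponding to the key should be processed, False otherwise.
The XMLs themselves are processed with a Pandas UDF, which I can't share.
My notebook on DataBricks basically works like this: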
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_part(part, fraction=1, filter=True, return_df=False):
    try:
        df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
    # Sometimes, the file part-xxxxx doesn't exist
    except AnalysisException:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        # Left join against keyfilter; keys missing from keyfilter become False
        df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
    if return_df:
        return mod_df

n_cores = 6
i = 0
# Process the 1024 parts in batches of n_cores threads
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i + n_cores)))
    i += 1
The reason I'm posting this question is that, even though the Pandas UDF should be the most expensive operation taking place, adding the filtering actually makes my code run much slower than if I weren't filtering at all. I am very new to Spark and I'm wondering if I'm doing something stupid here that is causing the joins with keyfilter to be very slow, and if so, whether there is a way to make them fast (e.g., is there a way to make keyfilter act like a hash table from keys to booleans, like CREATE INDEX in SQL?). I imagine that the large size of keyfilter is playing some kind of role here; it has 20 million rows, while df in process_part has only a tiny fraction of those rows (df is much larger in size, though, because it contains the XML files). Should I be combining all the parts into one giant dataframe instead of processing them one at a time?
Or is there a way to inform Spark that the keys are unique in both dataframes?
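For intuition about the "hash table from keys to booleans" idea, here is a minimal plain-Python sketch of what a hash join does conceptually: build a lookup table from the small side once, then probe it once per row of the large side. The function name and the toy rows are made up for illustration; this is not Spark code and says nothing about how Spark plans the join in the notebook above.

```python
# Plain-Python picture of a hash join: small side -> dict, large side -> streamed.
# All names and rows here are hypothetical and exist only for illustration.

def hash_join_filter(rows, keyfilter_rows):
    """Keep the (key, xml) rows whose key is flagged True in keyfilter_rows."""
    # Build phase: one pass over the small table creates the lookup table.
    keep = {key: flag for key, flag in keyfilter_rows}
    # Probe phase: each large-table row costs a single dictionary lookup.
    # Keys that are missing from the lookup table count as False.
    return [(key, xml) for key, xml in rows if keep.get(key, False)]


keyfilter_rows = [("a", True), ("b", False), ("c", True)]
rows = [("a", "<x/>"), ("b", "<y/>"), ("d", "<z/>")]

kept = hash_join_filter(rows, keyfilter_rows)
print(kept)  # [('a', '<x/>')]
```

The point of the sketch is that the large side is never sorted or shuffled; only the small side has to be held in memory in full.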
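Answer:
The key to getting the join done in a reasonable time frame was to use broadcast on keyfilter, so that Spark does a broadcast hash join instead of a standard join. I also merged some of the parts and lowered the parallelism (for some reason, too many threads sometimes seemed to crash the engine). My new, performant code looks like this: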
import pyspark
import time
from pyspark.sql.types import StringType
from functools import reduce
from pyspark.sql.functions import pandas_udf, col, broadcast
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_parts(part_pair, fraction=1, return_df=False, filter=True):
    dfs = []
    parts_start, parts_end = part_pair
    parts = range(parts_start, parts_end)
    # Read every part in the group that exists
    for part in parts:
        try:
            df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
            dfs.append(df)
        except AnalysisException:
            print("There is no part %d!" % part)
            continue
    # Merge the parts of the group into a single dataframe
    if len(dfs) >= 2:
        df = reduce(lambda x, y: x.union(y), dfs)
    elif len(dfs) == 1:
        df = dfs[0]
    else:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        # broadcast() hints Spark to use a broadcast hash join on keyfilter
        df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
    if return_df:
        return mod_df

start_time = time.time()
# 256 groups of 4 consecutive parts each, covering parts 0..1023
pairs = [(i*4, i*4 + 4) for i in range(256)]
with ThreadPool(3) as p:
    batch_start_time = time.time()
    for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
        batch_end_time = time.time()
        batch_len = batch_end_time - batch_start_time
        cum_len = batch_end_time - start_time
        print('Processed group %d/256 %d minutes and %d seconds after previous group.' % (i + 1, batch_len // 60, batch_len % 60))
        print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
        batch_start_time = time.time()
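One possible follow-up, offered as an untested idea rather than part of the answer above: since only about a third of the keys are flagged True, the left join plus fillna(False) plus filter should select the same rows as a semi join against just the True keys, assuming each key appears at most once in keyfilter. In PySpark that would presumably mean filtering keyfilter on its filter column first and joining with how='left_semi', which would also shrink the table being broadcast. The plain-Python sketch below only checks that the two formulations agree on made-up toy data; the function names are hypothetical.

```python
# Hypothetical toy data: (key, xml) rows and a key -> flag table.
# Assumption: every key appears at most once in keyfilter.
rows = [("a", "<x/>"), ("b", "<y/>"), ("d", "<z/>"), ("c", "<w/>")]
keyfilter = {"a": True, "b": False, "c": True}


def left_join_then_filter(rows, keyfilter):
    # Mirrors join(how='left').fillna(False).filter(col('filter')).drop('filter'):
    # attach the flag (False when the key is missing), then keep flagged rows.
    joined = [(key, xml, keyfilter.get(key, False)) for key, xml in rows]
    return [(key, xml) for key, xml, flag in joined if flag]


def semi_join_on_true_keys(rows, keyfilter):
    # Shrink the small side first: only the keys flagged True are needed.
    wanted = {key for key, flag in keyfilter.items() if flag}
    # Semi join: keep a row if its key is present; add no columns.
    return [(key, xml) for key, xml in rows if key in wanted]


result_left = left_join_then_filter(rows, keyfilter)
result_semi = semi_join_on_true_keys(rows, keyfilter)
print(result_left)                 # [('a', '<x/>'), ('c', '<w/>')]
print(result_left == result_semi)  # True
```

Whether this is actually faster on the real dataset would need to be measured; the sketch only shows that the selected rows are the same.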
