jax.numpy.split可用于將陣列分割成等長的段,余數在最后一個元素中。例如,將 5000 個元素的陣列拆分為 10 個段:
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
segments = jnp.split(array, split_indices)
這需要大約 10 秒才能在 Google Colab 和我的本地機器上執行。 對于小型陣列上的如此簡單的任務,這似乎是不合理的。我做錯了什么讓這變慢了嗎?
更多細節(JIT 快取,也許?)
提供相同形狀和相同拆分索引的陣列,后續呼叫.split非常快。例如,以下回圈的第一次迭代非常慢,但所有其他回圈都很快。(11 秒對 40 毫秒)
from timeit import default_timer as timer
import jax.numpy as jnp
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
for k in range(5):
start = timer()
segments = jnp.split(array, split_indices)
end = timer()
print(f'call {k}: {end - start:0.2f} s')
輸出:
call 0: 11.79 s
call 1: 0.04 s
call 2: 0.04 s
call 3: 0.05 s
call 4: 0.04 s
我假設后續呼叫更快,因為 JAX 正在快取split每個引陣列合的 jit 版本。如果是這種情況,那么我認為split由于編譯開銷而速度很慢(在第一次這樣的呼叫中)。
真的嗎?如果是,我應該如何拆分 JAX 陣列而不影響性能?
uj5u.com熱心網友回復:
這很慢,因為在 的實作中存在權衡split(),并且您的函式恰好處于權衡的錯誤方面。
在 XLA 中有幾種計算切片的方法,包括XLA:Slice (ie lax.slice)、XLA:DynamicSlice (ie lax.dynamic_slice) 和XLA:Gather (ie lax.gather)。
這些之間的主要區別在于開始和結束索引是靜態的還是動態的。靜態索引本質上意味著您要專門針對特定索引值進行計算:這會在第一次呼叫時產生一些小的編譯開銷,但后續呼叫可能會非常快。另一方面,動態索引不包括這種專門化,因此編譯開銷較小,但每次執行所需的時間稍長。你也許能猜到這是怎么回事……
jnp.split目前是根據lax.slice(參見代碼)實作的,這意味著它使用靜態索引。這意味著第一次使用jnp.split會產生與輸出數量成正比的編譯成本,但重復呼叫會很快執行。這似乎split是生成少數陣列的常見用途的最佳方法。
在您的情況下,您正在生成數百個陣列,因此編譯成本遠遠超過執行。
為了說明這一點,以下是基于 、 和 的三種相同陣列拆分方法的一些gather時間slice安排dynamic_slice。jnp.split如果您的程式受益于不同的實作,您可能希望直接使用其中之一,而不是使用:
from timeit import default_timer as timer
from jax import lax
import jax.numpy as jnp
import jax
def f_slice(x, step=10):
return [lax.slice(x, (N,), (N step,)) for N in range(0, x.shape[0], step)]
def f_dynamic_slice(x, step=10):
return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]
def f_gather(x, step=10):
step = jnp.asarray(step)
return [x[N: N step] for N in range(0, x.shape[0], step)]
def time(f, x):
print(f.__name__)
for k in range(5):
start = timer()
segments = jax.block_until_ready(f(x))
end = timer()
print(f' call {k}: {end - start:0.2f} s')
x = jnp.ones(5000)
time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)
這是 Colab CPU 運行時的輸出:
f_slice
call 0: 7.78 s
call 1: 0.05 s
call 2: 0.04 s
call 3: 0.04 s
call 4: 0.04 s
f_dynamic_slice
call 0: 0.15 s
call 1: 0.12 s
call 2: 0.14 s
call 3: 0.13 s
call 4: 0.16 s
f_gather
call 0: 0.55 s
call 1: 0.54 s
call 2: 0.51 s
call 3: 0.58 s
call 4: 0.59 s
您可以在這里看到靜態索引 ( lax.slice) 導致編譯后最快的執行。但是,為了生成很多切片,dynamic_slice并且gather避免重復編譯。可能我們應該jnp.split根據重新實作dynamic_slice,但這不會沒有權衡:例如,它會導致(可能更常見?)在少數分裂的情況下放緩,而lax.slice在這兩種情況下會更快初始和后續運行。此外,dynamic_slice僅當每個切片大小相同時才避免重新編譯,因此生成許多不同大小的切片會產生類似于lax.slice.
JAX 開發渠道中積極討論了這些權衡;在PR #12219中可以找到一個與此非常相似的最新示例。如果您想就這個特定問題發表意見,我會邀請您就該主題提交一個新的 jax 問題。
最后一點:如果您真的只是對生成陣列的等長順序切片感興趣,那么您最好只呼叫reshape:
out = x.reshape(len(x) // 10, 10)
結果現在是一個二維陣列,其中每一行對應于上述函式的一個切片,這將遠遠優于任何生成陣列切片串列的方法。
uj5u.com熱心網友回復:
Jax inbult 函式也是 JIT 編譯的
基準測驗 JAX 代碼
JAX 代碼是即時 (JIT) 編譯的。大多數用 JAX 撰寫的代碼都可以以支持 JIT 編譯的方式撰寫,這可以使其運行得更快(請參閱 To JIT or not to JIT)。要從 JAX 獲得最大性能,您應該在最外層的函式呼叫上應用 jax.jit()。
請記住,第一次運行 JAX 代碼時,它會因為正在編譯而變慢。即使您在自己的代碼中不使用 jit 也是如此,因為 JAX 的內置函式也是 JIT 編譯的。
所以第一次運行時,它正在編譯jnp.split(或者至少,編譯 jnp.split 中使用的一些函式)
%%timeit -n1 -r1
jnp.split(array, split_indices)
1min 15s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
第二次,呼叫編譯后的函式
%%timeit -n1 -r1
jnp.split(array, split_indices)
131 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
它相當復雜,呼叫其他jax.numpy函式,所以我認為編譯可能需要相當長的時間(在我的機器上 1 分鐘!)
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/519851.html
標籤:Python表现贾克斯
