我試圖加速一段代碼,在二維陣列的每一列上卷積一維陣列(過濾器)。不知何故,當我用 numba's 運行它時njit,我的速度減慢了 7 倍。我的想法:
- 也許列索引會減慢它的速度,但切換到行索引不會影響性能
- 也許切片索引卷積的結果很慢,但洗掉它并沒有改變任何東西
- 我已經檢查過 numba 是否正確理解所有型別
(在 Windows 10、conda 的 python 3.9.4、numpy 1.12.2、numba 0.53.1 上測驗)
誰能告訴我為什么這段代碼很慢?
import numpy as np
from numba import njit
def f1(a1, filt):
l2 = filt.size // 2
res = np.empty(a1.shape)
for i in range(a1.shape[1]):
res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]
return res
@njit
def f1_jit(a1, filt):
l2 = filt.size // 2
res = np.empty(a1.shape)
for i in range(a1.shape[1]):
res[:, i] = np.convolve(a1[:, i], filt)[l2:-l2]
return res
a1 = np.random.random((6400, 1000))
filt = np.random.random((65))
f1(a1, filt)
f1_jit(a1, filt)
%timeit f1(a1, filt) # 404 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f1_jit(a1, filt) # 2.8 s ± 66.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
uj5u.com熱心網友回復:
問題來自 Numba 的實作np.convolve。這是一個已知問題。事實證明,當前的 Numba 實作比 Numpy (在 Windows 上測驗的版本 <=0.54.1)慢得多。
引擎蓋下
一方面,Numpy 實作呼叫correlate本身執行點積,應該由系統上可用的快速 BLAS 庫實作。在另一方面,Numba執行呼叫_get_inner_prod它使用np.dot的是應該也使用相同的BLAS庫(假設BLAS檢測應該是這樣)...
話雖如此,有多個與點積相關的問題:
首先,如果內部變數_HAVE_BLAS的numba/np/arraymath.py是手動禁用,Numba使用后備實作應該是顯著慢的點積的。然而,事實證明,使用 by 使用的回退點積實作np.convolve比在我的機器上使用 BLAS 包裝器的執行速度快 5 倍!fastmath=True在njitNumba 裝飾器中額外使用該引數可使執行速度整體提高 8.7 倍!下面是測驗代碼:
import numpy as np
import numba as nb
def npConvolve(a, b):
return np.convolve(a, b)
@nb.njit('float64[:](float64[:], float64[:])')
def nbConvolveUncont(a, b):
return np.convolve(a, b)
@nb.njit('float64[::1](float64[::1], float64[::1])')
def nbConvolveCont(a, b):
return np.convolve(a, b)
a = np.random.random(6400)
b = np.random.random(65)
%timeit -n 100 npConvolve(a, b)
%timeit -n 100 nbConvolveUncont(a, b)
%timeit -n 100 nbConvolveCont(a, b)
以下是原始有趣的結果:
With _HAVE_BLAS=True (default):
126 μs ± 292 ns per loop
1.6 ms ± 21.3 μs per loop
1.6 ms ± 18.5 μs per loop
With _HAVE_BLAS=False:
125 μs ± 359 ns per loop
311 μs ± 1.18 μs per loop
268 μs ± 4.26 μs per loop
With _HAVE_BLAS=False and fastmath=True:
125 μs ± 757 ns per loop
327 μs ± 3.69 μs per loop
183 μs ± 654 ns per loop
此外,np_convolveNumba 在內部翻轉一些陣列引數,然后使用具有非平凡步幅(即非 1)的翻轉陣列執行點積。這種非平凡的步幅可能會對點積性能產生影響。更一般地說,任何阻止編譯器知道陣列是連續的轉換肯定會強烈影響性能。實際上,以下測驗顯示了使用 Numba 的點積實作處理連續陣列的影響:
import numpy as np
import numba as nb
def np_dot(a, b):
return np.dot(a, b)
@nb.njit('float64(float64[::1], float64[::1])')
def nb_dot_cont(a, b):
return np.dot(a, b)
@nb.njit('float64(float64[::1], float64[:])')
def nb_dot_stride(a, b):
return np.dot(a, b)
v = np.random.random(128*1024)
%timeit -n 200 np_dot(v, v) # 36.5 μs ± 4.9 μs per loop
%timeit -n 200 nb_dot_stride(v, v) # 361.0 μs ± 17.1 μs per loop (x10 !!!)
%timeit -n 200 nb_dot_cont(v, v) # 34.1 μs ± 2.9 μs per loop
關于 Numpy 和 Numba 的一些一般說明
請注意,Numba 在處理相當大的陣列時幾乎無法加速 Numpy 呼叫,因為 Numba主要在 Python 中重新實作 Numpy 函式并使用JIT 編譯器(LLVM-Lite) 來加速它們,而 Numpy 主要以簡單的方式實作 - C(帶有相當慢的 Python 包裝代碼)。Numpy 代碼使用諸如SIMD 指令之類的低級處理器功能來大大加快許多函式的執行速度。兩者似乎都使用已知高度優化的 BLAS 庫。Numpy 通常更趨于優化,因為 Numpy 目前比 Numba 更成熟:Numpy 有更多的貢獻者在更長的時間內作業。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/380994.html
上一篇:這些結構哪個更好?
