我有一個 1xN 陣列 A 和一個 2xM 陣列 B。我想制作兩個新的 1xN 陣列
- 檢查 B 的第一列是否在 A 中的布林值
- 如果 B[0,i] 在 A 中,則另一個條目 i 為 B[1,i],否則為 np.nan
我使用的任何方法都需要非常快,因為它會被呼叫很多。我可以使用這個來完成第一部分:對于大型陣列,是否有比 np.isin 更快的方法?
但我很難做第二部分的好方法。這是我到目前為止所得到的(改編上面帖子中的代碼):
import numpy as np
import numba as nb
@nb.jit(parallel=True)
def isinvals(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = list_vals.index(arr[i]) ## THIS LINE IS WAY TOO SLOW
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
N = int(1e5)
M = int(20e3)
num_arr = 100e3
num_vals = 20e3
num_types = 6
arr = np.random.randint(0, num_arr, N)
vals_col1 = np.random.randint(0, num_vals, M)
vals_col2 = np.random.randint(0, num_types, M)
vals = np.array([vals_col1, vals_col2])
%timeit result, result_vals = isinvals(arr,vals)
46.4 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
我在上面標記的線 ( list_vals.index(arr[i])) 是慢速部分。如果我不使用它,我可以制作一個超快速的版本:
@nb.jit(parallel=True)
def isinvals_cheating(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = 0 ## TEMPORARILY SETTING TO 0 TO INDICATE SPEED DIFFERENCE
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
%timeit result, result_vals = isinvals_cheating(arr,vals)
1.13 ms ± 59.7 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
即那條線使它慢了40倍。
有任何想法嗎?我也嘗試過使用 np.where() 但它更慢。
uj5u.com熱心網友回復:
vals[0, idx]假設 OP 的解決方案給出了預期的結果,因為對于具有不同對應值的非唯一值,這個問題似乎模棱兩可vals[1, idx]。查找表更快,但需要len(arr)額外的空間。
@nb.njit # tested with numba 0.55.1
def isin_nb(arr, vals):
lookup = np.empty(len(arr), np.float32)
lookup.fill(np.nan)
lookup[vals[0, ::-1]] = vals[1, ::-1]
res_val = lookup[arr]
return ~np.isnan(res_val), res_val
使用問題中使用的示例資料
res, res_val = isin_nb(arr, vals)
# %timeit 1000 loops, best of 5: 294 μs per loop
斷言相等的結果
np.testing.assert_equal(res, result)
np.testing.assert_equal(res_val, result_vals)
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/467242.html
