假設我有一個N = 10隨機浮點數的 Numpy 陣列:
import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)
out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963
5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]
我想找到所有唯一的數字對,它們彼此之間的差異不超過容差tol = 1.(即絕對差異 <= 1)。具體來說,我想獲得所有唯一的索引對。每個關閉對的索引都應該排序,所有關閉對都應該按第一個索引排序。我設法撰寫了以下作業代碼:
def all_close_pairs(arr, tol=1.):
res = set()
for i, x1 in enumerate(arr):
for j, x2 in enumerate(arr):
if i == j:
continue
if np.isclose(x1, x2, rtol=0., atol=tol):
res.add(tuple(sorted([i, j])))
res = np.array(list(res))
return res[res[:,0].argsort()]
print(all_close_pairs(arr, tol=1.))
out[2]: [[1 5]
[2 4]
[3 7]
[3 9]
[7 9]]
然而,實際上我有一個N = 1000數字陣列,由于嵌套的 for 回圈,我的代碼變得非常慢。我相信使用 Numpy 矢量化有更有效的方法來做到這一點。有誰知道在 Numpy 中執行此操作的最快方法?
uj5u.com熱心網友回復:
這是一個純 numpy 操作的解決方案。在我的機器上它似乎很快,但我不知道我們在尋找什么樣的速度。
def all_close_pairs(arr, tol=1.):
N = arr.shape[0]
# get indices in the array to consider using meshgrid
pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
# filter out pairs so we get indices in increasing order
pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
# compare indices in your array for closeness
is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
return pair_coords[is_close, :]
uj5u.com熱心網友回復:
一種有效的解決方案是首先使用對輸入值進行排序index = np.argsort()。然后,您可以使用 生成排序陣列arr[index],然后如果快速連續陣列上的對數很小,則在準線性時間內迭代接近值。如果對的數量很大,那么由于生成的對的平方數,復雜度是二次的。由此產生的復雜性是:其中是輸入陣列的大小,是生成的對數。O(n log n m)nm
要找到彼此接近的值,一種有效的方法是使用Numba迭代該值。實際上,雖然在 Numpy 中可能是可能的,但由于要比較的值的可變數量,它可能效率不高。這是一個實作:
import numba as nb
@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
res = []
for i in range(arr.size):
val = arr[i]
# Iterate over the close numbers (only once)
for j in range(i 1, arr.size):
# Sadly neither np.isclose or np.abs are implemented in Numba so far
if max(val, arr[j]) - min(val, arr[j]) >= tol:
break
res.append((i, j))
if len(res) == 0: # No pairs: we need to help Numpy to know the shape
return np.empty((0, 2), dtype=np.int32)
return np.array(res, dtype=np.int32)
最后,需要更新索引以參考未排序陣列中的索引,而不是已排序陣列中的索引。您可以使用index[result].
這是結果代碼:
index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])
這是結果(順序與問題中的順序不同,但您可以根據需要對其進行排序):
array([[9, 3],
[9, 7],
[3, 7],
[1, 5],
[4, 2]])
提高演算法的復雜度
如果您需要更快的演算法,那么您可以使用另一種輸出格式:您可以為每個輸入值提供接近目標輸入值的最小/最大范圍的值。要查找范圍,您可以np.searchsorted對已排序的陣列使用二分搜索(請參閱:)。生成的演算法在 中運行O(n log n)。但是,您無法獲得未排序陣列中的索引,因為該范圍將是不連續的。
基準
以下是在我的機器上隨機輸入 1000 項且容差為 1.0 的性能結果:
Reference implementation: ~17000 ms (x 1)
Angelicos' implementation: 1773 ms (x ~10)
Rivers' implementation: 122 ms (x 139)
Rchome's implementation: 20 ms (x 850)
Chris' implementation: 4.57 ms (x 3720)
This implementation: 0.67 ms (x 25373)
uj5u.com熱心網友回復:
有點晚了,但是一個全麻的解決方案:
import numpy as np
def close_enough( arr, tol = 1 ):
result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol ), 1))
return np.swapaxes( result, 0, 1 )
擴展以解釋正在發生的事情
def close_enough( arr, tol = 1 ):
bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
# is_close generates a square array after comparing all elements with all elements.
bool_arr = np.triu( bool_arr, 1 )
# Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal
# and all elements below and to the left.
result = np.where( bool_arr ) # Return the row and column indices for Trues
return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns
當 N = 1000 時,arr = 一個浮點陣列
%timeit close_enough( arr, tol = 1 )
14.1 ms ± 28.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [19]: %timeit all_close_pairs( arr, tol = 1 )
54.3 ms ± 268 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()
# True
uj5u.com熱心網友回復:
問題是您的代碼具有 O(n*n)(二次)復雜度。為了降低復雜性,您可以嘗試先對專案進行排序:
def all_close_pairs(arr, tol=1.):
res = set()
arr = sorted(enumerate(arr), key=lambda x: x[1])
for (idx1, (i, x1)) in enumerate(arr):
for idx2 in range(idx1-1, -1, -1):
j, x2 = arr[idx2]
if not np.isclose(x1, x2, rtol=0., atol=tol):
break
indices = sorted([i, j])
res.add(tuple(indices))
return np.array(sorted(res))
但是,這僅在您的值范圍遠大于容差時才有效。
您可以通過使用2 pointers策略進一步改進這一點,但整體復雜性將保持不變。
uj5u.com熱心網友回復:
您可以首先使用 itertools.combinations 創建組合:
def all_close_pairs(arr, tolerance):
pairs = list(combinations(arr, 2))
indexes = list(combinations(range(len(arr)), 2))
all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <= tolerance]
return all_close_pairs_indexes
現在,對于 N=1000,您只需比較 499500 對而不是 100 萬對。
這個怎么運作:
我們首先通過 itertools.combinations 創建對。
然后,我們創建它們的索引串列。
出于速度原因,我們使用串列推導式而不是 for 回圈。
在這個理解中,我們迭代所有對,使用
enumerate所以我們可以獲得對的索引,我們計算對中數字的絕對差,如果檢查它是否小于或等于tolerance。如果絕對差小于或等于
tolerance,我們通過索引串列獲得對數字的索引,并將它們添加到我們的最終串列中。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/368761.html
上一篇:在C中下溢時損壞的fgets輸入
