假設我有一個由0到10的隨機浮點陣列成的矩陣Y,形狀為(10,3):
import numpy as np
np.random.seed(99)
Y = np.random.uniform(0, 10, (10, 3))
print(Y)
輸出:
[[6.72278559 4.88078399 8.25495174]
[0.31446388 8.08049963 5.6561742 ]
[2.97622499 0.46695721 9.90627399]
[0.06825733 7.69793028 7.46767101]
[3.77438936 4.94147452 9.28948392]
[3.95454044 9.73956297 5.24414715]
[0.93613093 8.13308413 2.11686786]
[5.54345785 2.92269116 8.1614236 ]
[8.28042566 2.21577372 6.44834702]
[0.95181622 4.11663239 0.96865261] ]
我現在得到了一個矩陣X,其形狀與在Y中加入小的噪音,然后洗行得到的矩陣相同:
X = np.random.normal(Y, scale=0.1)
np.random.shuffle(X)
print(X)
輸出:
[[ 4.04067271 9.90959141 5.19126867]
[5.59873104 2.84109306 8.11175891]
[0.10743952 7.74620162 7.51100441]
[3.60396019 4.91708372 9.07551354]
[0.9400948 4.15448712 1.04187208[/span>]
[2.91884302 0.47222752 10.12700505]
[0.30995155 8.09263241 5.74876947]
[1.112472 8.02092335 1.99767444]
[6.68543696 4.8345869 8.17330513]
[8.38904822 2.11830619 6.42013343]]
現在我想根據Y對矩陣X進行排序按行。我已經知道在每一對匹配的行中的每一對列值之間的差異都不超過0.5的公差。我設法寫了下面的代碼,它作業得很好。
def sort_X_by_Y(X, Y, tol)。
idxs = [next(i for i in range(len(X) ) if all(abs(X[i] - row) <= tol)) for row in Y] 。
return X[idxs] 。
print( sort_X_by_Y(X, Y, tol=0.5)
輸出:
[[ 6.68543696 4.8345869 8.17330513 ]
[0.30995155 8.09263241 5.74876947]
[2.91884302 0.47222752 10.12700505]
[0.10743952 7.74620162 7.51100441]
[3.60396019 4.91708372 9.07551354]
[4.04067271 9.90959141 5.19126867]
[1.112472 8.02092335 1.99767444[/span>]
[5.59873104 2.84109306 8.11175891]
[8.38904822 2.11830619 6.42013343]
[0.9400948 4.15448712 1.04187208]]
然而,在現實中,我正在對(1000, 3)矩陣進行排序,我的代碼太慢了。我覺得應該有更多的numpyic方式來編碼這個。有什么建議嗎?
uj5u.com熱心網友回復:
這是你的演算法的一個矢量版本。對于1000個樣本,它的運行速度比你的實作快~26.5倍。但是一個額外的布爾陣列,其形狀為(1000,1000,3),被創建。有可能在公差范圍內的行會有類似的值,從而選擇了一個錯誤的行。
tol=.5
X[(np.abs(Y[:, np.newaxis] - X) <= tol).all(2).argmax(1)]
輸出
array([[ 6.68543696, 4.8345869 , 8.17330513]。
[0.30995155, 8.09263241, 5.74876947] 。
[2.91884302, 0.47222752, 10.12700505] 。
[0.10743952, 7.74620162, 7.51100441] 。
[3.60396019, 4.91708372, 9.07551354] 。
[4.04067271, 9.90959141, 5.19126867] 。
[1.112472, 8.02092335, 1.99767444] 。
[5.59873104, 2.84109306, 8.11175891] 。
[8.38904822, 2.11830619, 6.42013343] 。
[0.9400948 , 4.15448712, 1.04187208]])
使用L1-norm的更穩健的解決方案
X[np.abs(Y[:, np. newaxis] - X).sum(2).argmin(1)]
或者L2-norm
X[((Y[:, np.newaxis] - X)**2)。 sum(2).argmin(1)]
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/309469.html
標籤:
