我在執行 np.argpartition 時遇到問題我有 nd.array
example = np.array([[5,6,7,3,4],[1,2,3,7,5],[6,7,4,2,3],[1,2,3,5,9],[2,3,6,1,2,]])
out: [[5 6 7 3 4]
[1 2 3 7 5]
[6 7 4 2 3]
[1 2 3 5 9]
[2 3 6 1 2]]
我可以通過 np.argsort 獲取排序陣列的索引
print(np.argsort(example))
out:
[[3 4 0 1 2]
[0 1 2 4 3]
[3 4 2 0 1]
[0 1 2 3 4]
[3 0 4 1 2]]
我想使用 np.argsort 來節省一些執行時間,因為在這個陣列的每一行中我只需要 3 個排序元素。我使用此代碼來做到這一點:
print(np.argpartition(example, 3, axis=1))
out: [[3 4 0 1 2]
[1 0 2 4 3]
[3 4 2 0 1]
[1 0 2 3 4]
[3 4 0 1 2]]
我希望每行的前三個索引將與排序陣列中的索引匹配,但事實并非如此 ю 行不通。我不明白我做錯了什么。
uj5u.com熱心網友回復:
np.argpartition(example, k, axis=1)不回傳前 k 個元素的排序陣列。它只回傳索引,使得只有第 (k 1) 個元素被排序。如果您在輸出中看到,只有第 4 個元素與argsort()
如果你想要前三個排序的元素,你必須給出 k 引數的串列
index_array = np.argpartition(example, [0,1,2], axis=1)
print(np.take_along_axis(example,index_array, axis=1)) ##this will give you first 3 sorted elements
轉載請註明出處,本文鏈接:https://www.uj5u.com/qianduan/321532.html
