我正在嘗試構建一個函式,該函式回傳資料點和質心之間最短距離的索引。但是我遇到了一個錯誤IndexError: arrays used as indices must be of integer (or boolean) type。
import numpy as np
b = np.ndarray((3, 2, 1, 4))
DEFAULT_CENTROIDS = np.array([[5.664705882352942, 3.0352941176470587, 3.3352941176470585, 1.0176470588235293],
[5.446153846153847, 3.2538461538461543, 2.9538461538461536, 0.8846153846153846],
[5.906666666666667, 2.933333333333333, 4.1000000000000005, 1.3866666666666667],
[5.992307692307692, 3.0230769230769234, 4.076923076923077, 1.3461538461538463],
[5.747619047619048, 3.0714285714285716, 3.6238095238095243, 1.1380952380952383],
[6.161538461538462, 3.030769230769231, 4.484615384615385, 1.5307692307692309],
[6.294117647058823, 2.9764705882352938, 4.494117647058823, 1.4],
[5.853846153846154, 3.215384615384615, 3.730769230769231, 1.2076923076923078],
[5.52857142857143, 3.142857142857143, 3.107142857142857, 1.007142857142857],
[5.828571428571429, 2.9357142857142855, 3.664285714285714, 1.1]])
def get_closest(data_point: np.ndarray, centroids: np.ndarray):
"""
Takes a data_point and a nd.array of multiple centroids and returns the index of the centroid closest to data_point
by computing the euclidean distance for each centroid and picking the closest.
"""
N = centroids.shape[0]
dist = np.empty(N)
for i in centroids:
dist[i] = np.linalg.norm(centroids[i]-data_point)
index_min = np.argmin(dist)
return index_min # the index of the centroid closest to the datapoint
get_closest(b,DEFAULT_CENTROIDS)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-52-7c41e5374a5e> in <module>
----> 1 get_closest(b,DEFAULT_CENTROIDS)
<ipython-input-51-8e1cd568d0df> in get_closest(data_point, centroids)
11 dist = np.empty(N)
12 for i in centroids:
---> 13 dist[i] = np.linalg.norm(centroids[i]-data_point)
14 index_min = np.argmin(dist)
15 return index_min # the index of the centroid closest to the datapoint
IndexError: arrays used as indices must be of integer (or boolean) type
我不太明白錯誤或我的代碼錯誤的原因。建議?提前致謝。
uj5u.com熱心網友回復:
for i in centroids:-> centroids[i]。
i是 的一個元素centroids,而不是它的索引。因此,當您嘗試按浮點陣列進行索引時,它不起作用。
對于您似乎想要的東西,我認為enumerate這將是您最好的選擇。
def get_closest(data_point: np.ndarray, centroids: np.ndarray):
"""
Takes a data_point and a nd.array of multiple centroids and returns the index of the centroid closest to data_point
by computing the euclidean distance for each centroid and picking the closest.
"""
N = centroids.shape[0]
dist = np.empty(N)
for i, c in enumerate(centroids):
dist[i] = np.linalg.norm(c - data_point)
index_min = np.argmin(dist)
return index_min # the index of the centroid closest to the datapoint
也就是說,在這種情況下,您可能最好使用KDTreeif centroidsis large 或者您的目標是速度。
from scipy.spatial import KDTree
centroid_KDTree = KDTree(centroids)
dist, index_min = centroid_KDTree.query(data_point)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/464599.html
