我有一個(N,)浮點數 ( arr)陣列,但我只關心 >= 給定的條目threshold。我可以獲得這樣的面具:
mask = (arr >= threshold)
現在我想要一個(N,2)相應切片索引的陣列。
例如,如果arr = [0, 0, 1, 1, 1, 0, 1, 1, 0, 1]和threshold = 1,那么mask = [False, False, True, True, True, False, True, True, False, True],我想要索引[ [2, 5], [6, 8], [9, 10] ](我可以用它arr[2:5], arr[6:8], arr[9:10]來獲取段 where arr >= threshold)。
目前,我有一個丑陋的 for 回圈解決方案,它True在將相應的切片索引附加到串列之前跟隨每一段。有沒有更簡潔易讀的方法來實作這個結果?
uj5u.com熱心網友回復:
您可以將 itertools groupby 與key引數一起使用,enumerate以獲取分組。如果組值是全部,True您可以取第一個和最后一個 1 值。
from itertools import groupby
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
idx = []
for group,data in groupby(enumerate((arr >= threshold)), key=lambda x:x[1]):
d = list(data)
if all(x[1]==True for x in d):
idx.append([d[0][0], d[-1][0] 1])
輸出
[[2, 5], [6, 8], [9, 10]]
uj5u.com熱心網友回復:
您可以使用組合np.flatnonzero和np.diff:
indexes = np.flatnonzero(np.diff(np.append(arr >= threshold, 0))) 1
indexes = list(zip(indexes[0::2], indexes[1::2]))
輸出:
>>> indexes
[(2, 5), (6, 8), (9, 10)]
uj5u.com熱心網友回復:
您可以通過將掩碼布林值與其后繼者進行比較來使用掩碼計算開始和結束索引的串列。然后連接開始和結束以形成范圍(全部使用 numpy 方法矢量化):
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
mask = arr >= threshold
starts = np.argwhere(np.insert(mask[:-1],0,False)<mask)[:,0]
ends = np.argwhere(np.append(mask[1:],False)<mask)[:,0] 1
indexes = np.stack((starts,ends)).T
print(starts) # [2 6 9]
print(ends) # [5 8 10]
print(indexes)
[[ 2 5]
[ 6 8]
[ 9 10]]
如果你想在元組的 Python 串列中得到結果:
indexes = list(zip(starts,ends)) # [(2, 5), (6, 8), (9, 10)]
如果您不需要(或不想)使用 numpy,您可以使用 itertools 中的 groupby 直接從 arr 獲取范圍:
from itertools import groupby
indexes = [ (t[1],t[-1] 1) for t,t[1:] in
groupby(range(len(arr)),lambda i:[arr[i]>=threshold]) if t[0]]
print(indexes)
[(2, 5), (6, 8), (9, 10)]
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/391873.html
上一篇:如何從陣列的每一列中減去第一列
