我正在嘗試從陣列中提取陣列的一部分。
假設我有一個array1shape陣列(M, N, P)。對于我的具體情況,M = 10, N = 5, P = 2000。我有另一個陣列array2shape (M, N, 1),其中包含array1沿最后一個軸的有趣資料的起點。我想從 給出的索引開始提取這個資料的 50 個點array2,有點像這樣:
array1[:, :, array2:array2 50]
我希望得到 shape 的結果(M, N, 50)。不幸的是我收到錯誤:
TypeError: only integer scalar arrays can be converted to a scalar index
當然我也可以通過遍歷陣列來得到結果,但我覺得必須有一個更聰明的方法,因為我經常需要這個。
uj5u.com熱心網友回復:
由于您在每個位置的索引沒有對齊,您可以創建一個掩碼或花哨的索引來提取所需的元素。由于提取的值將是平坦的,因此您將不得不重塑它們。
以下是創建蒙版的方法:
K = 50
mask = np.zeros((M, N, P 1), dtype=np.int8)
np.put_along_axis(mask, array2, 1, axis=-1)
np.put_along_axis(mask, array2 K, -1, axis=-1)
mask.cumsum(axis=-1, out=mask)
mask = mask[..., :-1].view(bool)
我們利用np.int8和np.bool_具有相同專案大小的事實,并將np.cumsum初始掩碼位置傳播到每個軸的最終位置。
剩下的就很簡單了:
array3 = array1[mask].reshape(M, N, K)
在構建掩碼時,您可以通過繞過np.put_along_axis并在適當的情況下使用帶有剪輯的直接索引來避免額外的元素:
mask = np.zeros_like(array1, dtype=np.int8)
r = np.tile(np.arange(M)[:, None, None], (1, N, 1))
c = np.tile(np.arange(N)[None, :, None], (M, 1, 1))
clip_mask = array2 K < P
mask[r, c, array2] = 1
mask[r[clip_mask], c[clip_mask], array2[clip_mask] K] = -1
mask = np.cumsum(mask, axis=-1, out=mask).view(bool)
這一切都非常浪費:要獲得一個 shape 陣列(M, N, K),您正在創建一個 size 的布爾掩碼(M, N, P)以及一些 size 的索引陣列(M, N, 1)、另一個 size 的掩碼,(M, N, 1)然后是這些索引陣列的一些掩碼版本。在for這里使用回圈確實沒有錯,只要您編譯它們,例如使用 cython 或 numba。
uj5u.com熱心網友回復:
您可以使用 array2 中的值與最后一個維度的索引范圍的比較來構建掩碼:
例如:
import numpy as np
M,N,P,k = 4,2,15,3 # yours would be 10,5,2000,50
A1 = np.arange(M*N*P).reshape((M,N,P))
A2 = np.arange(M*N).reshape((M,N,1)) 1
rP = np.arange(P)[None,None,:]
A3 = A1[(rP>=A2)&(rP<A2 k)].reshape((M,N,k))
輸入:
print(A1)
[[[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
[ 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29]]
[[ 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44]
[ 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]]
[[ 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74]
[ 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89]]
[[ 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104]
[105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]]]
print(A2)
[[[1]
[2]]
[[3]
[4]]
[[5]
[6]]
[[7]
[8]]]
輸出:
print(A3)
[[[ 1 2 3]
[ 17 18 19]]
[[ 33 34 35]
[ 49 50 51]]
[[ 65 66 67]
[ 81 82 83]]
[[ 97 98 99]
[113 114 115]]]
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/372117.html
標籤:Python 数组 麻木的 numpy-indexing
上一篇:在Python中制作串列串列
下一篇:矩陣乘法和solve_ivp
