我有一個形狀為 arr.shape = N,M,M 的 numpy 陣列。
我想訪問每個 M,M 陣列的下三角形。我嘗試使用
arr1 = arr[:,np.tril_indices(M,-1)]
arr1 = arr[:][np.tril_indices(M,-1)]
等等,在第一種情況下內核死亡,而在第二種情況下,我收到一個錯誤訊息:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-23-1b36c5b12706> in <module>
----> 1 arr1 = arr[:][np.tril_indices(M,-1)]
IndexError: index 6 is out of bounds for axis 0 with size 6
在哪里
N=6
為了澄清我想找到每個 M,M 陣列(N 個這樣的實體)的下三角形中的所有元素,并將結果保存在一個新的形狀陣列中:
arr1.shape = (N,(M*(M-1))/2)
編輯:
雖然 np.tril(arr) 有效,但它會產生一個陣列
arr1 = np.tril(arr)
arr1.shape
#(N,M,M)
我希望生成的陣列具有指定的形狀,即我不想要陣列的上部
謝謝
uj5u.com熱心網友回復:
import numpy as np
a = np.random.rand(2, 5, 5)
#array([[[0.28212197, 0.29827562, 0.05151153, 0.90448236, 0.07521404],
# [0.38938978, 0.67007919, 0.83561652, 0.5950061 , 0.73563179],
# [0.77515285, 0.31973392, 0.91861436, 0.87386527, 0.85917542],
# [0.12588184, 0.09173029, 0.28577701, 0.4884228 , 0.07183555],
# [0.68656271, 0.19941039, 0.07924489, 0.15046004, 0.91011737]],
#
# [[0.18662788, 0.45745028, 0.14557573, 0.22425571, 0.14204739],
# [0.44502694, 0.85773626, 0.78554919, 0.07306402, 0.14608384],
# [0.70620254, 0.81497515, 0.09397011, 0.32053184, 0.255485 ],
# [0.50139688, 0.51539848, 0.24719375, 0.80708819, 0.39685176],
# [0.94052069, 0.53927081, 0.39567362, 0.06065674, 0.53479994]]])
np.tril(a)
#array([[[0.28212197, 0. , 0. , 0. , 0. ],
# [0.38938978, 0.67007919, 0. , 0. , 0. ],
# [0.77515285, 0.31973392, 0.91861436, 0. , 0. ],
# [0.12588184, 0.09173029, 0.28577701, 0.4884228 , 0. ],
# [0.68656271, 0.19941039, 0.07924489, 0.15046004, 0.91011737]],
#
# [[0.18662788, 0. , 0. , 0. , 0. ],
# [0.44502694, 0.85773626, 0. , 0. , 0. ],
# [0.70620254, 0.81497515, 0.09397011, 0. , 0. ],
# [0.50139688, 0.51539848, 0.24719375, 0.80708819, 0. ],
# [0.94052069, 0.53927081, 0.39567362, 0.06065674, 0.53479994]]])
如果要洗掉零并將其展平為(2, 15)陣列(請注意,每個下三角形陣列中有 10 個零)-
a_no_zeros = np.array([el
for mat in a_lower
for row in mat
for el in row
if el > 0
]).reshape(2, 15)
uj5u.com熱心網友回復:
在使用這tri...組函式時,檢查源代碼會很有用。它們都是 python,并基于np.tri.
制作一個小樣本陣列 - 來說明和驗證答案:
In [205]: arr = np.arange(18).reshape(2,3,3) # arange(1,19) might be better
In [206]: arr
Out[206]:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
tril 將上三角值設定為 0。它在這種情況下有效,但沒有記錄到 3d 陣列的應用程式。
In [207]: np.tril(arr)
Out[207]:
array([[[ 0, 0, 0],
[ 3, 4, 0],
[ 6, 7, 8]],
[[ 9, 0, 0],
[12, 13, 0],
[15, 16, 17]]])
但是在代碼中 if first 從最后 2 個維度構造一個布爾掩碼:
In [208]: mask = np.tri(*arr.shape[-2:], dtype=bool)
In [209]: mask
Out[209]:
array([[ True, False, False],
[ True, True, False],
[ True, True, True]])
并用于np.where將一些值設定為 0。這在 3d 情況下通過廣播起作用。 mask并arr匹配最后兩個維度,因此mask可以broadcast匹配:
In [210]: np.where(mask, arr, 0)
Out[210]:
array([[[ 0, 0, 0],
[ 3, 4, 0],
[ 6, 7, 8]],
[[ 9, 0, 0],
[12, 13, 0],
[15, 16, 17]]])
你tril_indices只是這個面具的索引:
In [217]: np.nonzero(mask) # aka np.where
Out[217]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [218]: np.tril_indices(3)
Out[218]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
它們不能直接用于索引arr:
In [220]: arr[np.tril_indices(3)].shape
Traceback (most recent call last):
File "<ipython-input-220-e26dc1f514cc>", line 1, in <module>
arr[np.tril_indices(3)].shape
IndexError: index 2 is out of bounds for axis 0 with size 2
In [221]: arr[:,np.tril_indices(3)].shape
Out[221]: (2, 2, 6, 3)
但是解壓兩個索引陣列:
In [222]: I,J = np.tril_indices(3)
In [223]: I,J
Out[223]: (array([0, 1, 1, 2, 2, 2]), array([0, 0, 1, 0, 1, 2]))
In [224]: arr[:,I,J]
Out[224]:
array([[ 0, 3, 4, 6, 7, 8],
[ 9, 12, 13, 15, 16, 17]])
也可以直接使用布爾掩碼:
In [226]: arr[:,mask]
Out[226]:
array([[ 0, 3, 4, 6, 7, 8],
[ 9, 12, 13, 15, 16, 17]])
The base np.tri works by simply doing an outer >= on indices
In [231]: m = np.greater_equal.outer(np.arange(3),np.arange(3))
In [232]: m
Out[232]:
array([[ True, False, False],
[ True, True, False],
[ True, True, True]])
In [234]: np.arange(3)[:,None]>=np.arange(3)
Out[234]:
array([[ True, False, False],
[ True, True, False],
[ True, True, True]])
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/355593.html
標籤:Python 数组 麻木的 numpy-ndarray numpy-ufunc
