給定一個 3x3 的 F 矩陣和一些 N x 3 的點矩陣,我想有效地計算 E = bT @ F @ a。
F = np.arange(9).reshape(3, 3)
>>> [[0 1 2]
[3 4 5]
[6 7 8]]
a = np.array([[1, 2, 1], [3, 4, 1], [5, 6, 1], [7, 8, 1]])
>>>[[1 2 1]
[3 4 1]
[5 6 1]
[7 8 1]]
b = np.array([[10, 20, 1],[30, 40, 1],[50, 60, 1],[70, 80, 1]])
>>>[[10 20 1]
[30 40 1]
[50 60 1]
[70 80 1]]
預期的輸出是:
E = [388 1434 3120 5446]
我可以用一個簡單的 for 回圈來獲得它,但我想用所有 numpy. 我嘗試重塑 a 和 b 矩陣,但這并沒有完全奏效。
N = b.shape[0] #4
a_reshaped = a.reshape(N, 3, 1)
b_reshaped = b.reshape(1, N, 3)
F_repeated = np.repeat(F[None,:], N, axis=0)
E = b_reshaped @ F_repeated @ a_reshaped
>>>
[[[ 388]
[ 788]
[1188]
[1588]]
[[ 714]
[1434]
[2154]
[2874]]
[[1040]
[2080]
[3120]
[4160]]
[[1366]
[2726]
[4086]
[5446]]]
如果我然后取對角線值,我會得到預期的結果,這是非常低效的。
有什么建議?
編輯:這是我所描述的 for 回圈版本:
E = []
for k in range(N):
error = b[k].T @ F @ a[k]
E.append(error)
E = np.array(E)
>>>[388 1434 3120 5446]
uj5u.com熱心網友回復:
根據@haveaball 的評論,您無法從一系列二維矩陣乘法中獲得一維陣列。
看起來你想要的是:
(b@[email protected]).diagonal()
(或(a@(b@F).T).diagonal())
輸出: array([ 388, 1434, 3120, 5446])
uj5u.com熱心網友回復:
我找到了一個避免制作 NxN 矩陣的解決方案:
partial_mult_b = b @ F
E_prime = partial_mult_b * a
E = np.sum(E_prime, axis=1)
uj5u.com熱心網友回復:
首先,“天真”迭代版本 - 清除操作:
In [67]: c = np.zeros(4,int)
In [68]: for i in range(4):
...: c[i] = b[i]@F@a[i]
...:
In [69]: c
Out[69]: array([ 388, 1434, 3120, 5446])
簡單的 einsum 版本
In [70]: np.einsum('ij,jk,ik->i',b,F,a)
Out[70]: array([ 388, 1434, 3120, 5446])
批處理 matmul 版本
In [71]: b[:,None,:]@F@a[:,:,None]
Out[71]:
array([[[ 388]],
[[1434]],
[[3120]],
[[5446]]])
In [72]: (b[:,None,:]@F@a[:,:,None]).squeeze()
Out[72]: array([ 388, 1434, 3120, 5446])
這個:
In [73]: ((b@F)*a).sum(axis=1)
Out[73]: array([ 388, 1434, 3120, 5446])
與將 擴展einsum為兩個步驟相同:
In [74]: np.einsum('ik,ik->i',np.einsum('ij,jk->ik',b,F),a)
Out[74]: array([ 388, 1434, 3120, 5446])
轉載請註明出處,本文鏈接:https://www.uj5u.com/caozuo/338370.html
下一篇:查找具有nan值的行并將其洗掉
