這是將 np.linalg.multi_dot() 函式與 functools.reduce(np.matmul, Nx2x2_arrays) 等 Nx2x2 陣列一起使用的合理方法嗎?請看下面的例子。
import numpy as np
from functools import reduce
m1 = np.array(range(16)).reshape(4, 2, 2)
m2 = m1.copy()
m3 = m1.copy()
reduce(np.matmul, (m1, m2, m3))
結果 - 4x2x2 陣列:
array([[[ 6, 11],
[ 22, 39]],
[[ 514, 615],
[ 738, 883]],
[[ 2942, 3267],
[ 3630, 4031]],
[[ 8826, 9503],
[10234, 11019]]])
如您所見, np.matmul 將 4x2x2 3-D 陣列視為 2x2 矩陣的 1-D 陣列。我可以使用 np.linalg.multi_dot() 而不是 reduce(np.matmul) 來做同樣的事情,如果是,它會導致任何性能改進嗎?
uj5u.com熱心網友回復:
np.linalg.multi_dot() 嘗試通過找到導致總體乘法最少的點積順序來優化操作。
由于您所有的矩陣都是方陣,點積的順序無關緊要,您將始終得到相同數量的乘法。
在內部,np.linalg.multi_dot()不運行任何 C 代碼,而只是呼叫np.dot(),因此您可以執行相同的操作:
functools.reduce(np.matmul, (m1, m2, m3))
或者干脆
m1 @ m2 @ m3
uj5u.com熱心網友回復:
您還可以使用np.einsum():
np.einsum('ijk,ikl,ilm->ijm',m1,m2,m3)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/315967.html
上一篇:BigQuery:如何將陣列的相似記錄分組到逗號分隔的欄位中?
下一篇:處理嵌套陣列的回圈
