我有以下代碼片段:
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
我如何將其概括為任何a(n,3,3),b(n,3)和c(n,3)?
我認為這einsum是要走的路,但我無法弄清楚正確的語法......
uj5u.com熱心網友回復:
您可以廣播或使用 einsum(更好的 einsum):
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
res_broad = (a*b[:,None,:]).sum(2)
res_ein = np.einsum('ijk,ik->ij',a,b)
print(f"broadcast works: {np.allclose(c,res_broad)}")
print(f"einsum works: {np.allclose(c,res_broad)}")
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/485060.html
上一篇:元組到numpy,資料準確性
下一篇:將ndarray轉換為字串陣列
