einsum函式說明
pytorch檔案說明:\(torch.einsum(equation, **operands)\) 使用基于愛因斯坦求和約定的符號,將輸入operands的元素沿指定的維數求和,einsum允許計算許多常見的多維線性代數陣列運算,方法是基于愛因斯坦求和約定以簡寫格式表示它們,主要是省略了求和號,總體思路是在箭頭左邊用一些下標標記輸入operands的每個維度,并在箭頭右邊定義哪些下標是輸出的一部分,通過將operands元素與下標不屬于輸出的維度的乘積求和來計算輸出,其方便之處在于可以直接通過求和公式寫出運算代碼,
# 矩陣乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等價操作 torch.mm(a, b)
兩個基本概念,自由索引/自由標(Free indices)和求和索引/啞標(Summation indices):
- 自由索引,出現在箭頭右邊的索引
- 求和索引,只出現在箭頭左邊的索引,表示中間計算結果需要這個維度上求和之后才能得到輸出,
接著是介紹三潭訓本規則:
- 規則一,equation 箭頭左邊,在不同輸入之間重復出現的索引表示,把輸入張量沿著該維度做乘法操作,比如還是以上面矩陣乘法為例, "ik,kj->ij",k 在輸入中重復出現,所以就是把 a 和 b 沿著 k 這個維度作相乘操作;
- 規則二,只出現在 equation 箭頭左邊的索引,表示中間計算結果需要在這個維度上求和,也就是上面提到的求和索引;
- 規則三,equation 箭頭右邊的索引順序可以是任意的,比如上面的 "ik,kj->ij" 如果寫成 "ik,kj->ji",那么就是回傳輸出結果的轉置,用戶只需要定義好索引的順序,轉置操作會在 einsum 內部完成,
兩條特殊規則:
- equation 可以不寫包括箭頭在內的右邊部分,那么在這種情況下,輸出張量的維度會根據默認規則推導,就是把輸入中只出現一次的索引取出來,然后按字母表順序排列,比如上面的矩陣乘法 "ik,kj->ij" 也可以簡化為 "ik,kj",根據默認規則,輸出就是 "ij" 與原來一樣;
- equation 中支持 "..." 省略號,用于表示用戶并不關心的索引,詳見下方轉置例子
單運算元
獲取對角線元素diagonal
einsum 可以不做求和,舉個例子,獲取二維方陣的對角線元素,結果放入一維向量,
\[A_i = B_{ii} \]上面,A 是一維向量,B 是二維方陣,使用 einsum 記法,可以寫作 ii->i
torch.einsum('ii->i', torch.randn(4, 4))
# 以下操作互相等價
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)
跡trace
求解矩陣的跡(trace),即對角線元素的和,
\[t = \Sigma_{i=1}^{n} A_{ii} \]t 是常量,A 是二維方陣,按照前面的做法,省略 ΣΣ,左右兩邊對調,省去矩陣和 t,剩下的就是ii->或省略箭頭ii
torch.einsum('ii', torch.randn(4, 4))
矩陣轉置
\[A_{ij} = B_{ji} \]A 和 B 都是二維方陣,einsum 可以表達為 ij->ji,
torch.einsum('ij -> ji',a)
pytorch 中,還支持省略前面的維度,比如,只轉置最后兩個維度,可以表達為 ...ij->...ji,下面展示了一個含有四個二維矩陣的三維矩陣,轉置三維矩陣中的每個二維矩陣,
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])
# 等價操作
A.permute(0,1,3,2)
A.transpose(2,3)
求和
\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)
列求和:
\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
# 等價操作
torch.sum(a, 0) # (dim引數0) means the dimension or dimensions to reduce.
雙運算元
矩陣乘法
\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]第一個學習的 einsum 運算式是,ik,kj->ij,前面提到過,愛因斯坦求和記法可以理解為懶人求和記法,將上述公式中的 ΣΣ 去掉,并且將左右兩邊對調一下,省去矩陣之后,剩下的就是 ik,kj->ij 了,
torch.einsum('ik,kj->ij', a, b)
# 可用兩個矩陣測驗以下矩陣乘法操作互相等價
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩陣-向量相乘
\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
批量矩陣乘 batch matrix multiplication
\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
# 等價操作
torch.bmm(As, Bs)
向量內積 dot
\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)
# 等價操作
torch.dot(a, b)
矩陣內積 dot
\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)
哈達瑪積
\[C_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])
外積 outer
\[C_{i j}=a_{i} b_{j} \]a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])
einsum其他規則和例子判斷:
- 輸入中多次出現的字符,將被用作求和,例子,
kj,ji完整的運算式是kj,ji->ik,矩陣乘法再相乘, - 輸出可以指定,但是輸出中的每個字符必須在輸入中出現至少一次,輸出的每個字符在輸出中只能出現最多一次,例子,
ab->aa是非法的,ab->c是非法的,ab->a是合法的, - 省略符
...是用來跳過部分維度,例子,...ij,...jk表示 batch 矩陣乘法, - 在輸出沒有指定的情況下,省略符優先級高于普通字符,例子,
b...a完整的運算式是b...a->...ab,可以將一個形狀為(a,b,c)的矩陣變為形狀為(b,c,a)的矩陣, - 允許多個矩陣輸入,運算式中使用逗號分開不同矩陣輸入的下標,例子,
i,i,i表示將三個一維向量按位相乘,并相加, - 除了箭頭,其他任何地方都可以加空格,例子,
i j , j k -> ik是合法的,ij,jk - > ik是非法的, - 輸入的運算式,維度需要和輸入的矩陣對上,不能多也不能少,比如一個 shape 為
(4,3,3)的矩陣,運算式ab->a是非法的,abc->是合法的,
實際使用
實作multi headed attention
https://nn.labml.ai/transformers/mha.html
如何優雅地實作多頭自注意力
計算注意力score:
\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]# q k v均為 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解為ibhd,jbhd->ibhj->ijbh
計算attention輸出:
\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]
參考文獻:
https://zhuanlan.zhihu.com/p/361209187
如何優雅地實作多頭自注意力
https://rockt.github.io/2018/04/30/einsum **
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/470644.html
標籤:其他
上一篇:C++進階-3-5-list容器
