技術背景
Vmap是一種在python里面經常提到的向量化運算的功能,比如之前大家常用的就是numba和jax中的向量化運算的介面,雖然numpy中也使用到了向量化的運算,比如計算兩個numpy陣列的加和,就是一種向量化的運算,但是在numpy中模塊封裝的較好,定制化程度低,但是使用便捷,只需要呼叫最上層的介面即可,現在最新版本的mindspore也已經推出了vmap的功能,像mindspore、numba還有jax,與numpy的最大區別就是,需要在使用程序中對需要向量化運算的函式額外嵌套一層vmap的函式,這樣就可以實作只對需要向量化運算的模塊進行擴展,用一個公式來理解向量化運算的話就是:
\[a_1+b_1=c_1\\ a_2+b_2=c_2\\ .\\ .\\ .\\ a_n+b_n=c_n\\ \Downarrow\\ \vec{a}+\vec{b}=\vec{c} \]安裝最新版MindSpore
關于jax中的vmap使用案例,可以參考前面介紹的LINCS約束演算法實作和SETTLE約束演算法批量化實作這兩篇文章,都有使用到jax的vmap功能,這里我們著重介紹的是MindSpore中最新實作的vmap功能,首先我們需要安裝mindspore最新的Nightly版本,其對應的是MindSpore的Gitee倉庫中的master分支,具體安裝指令可以參考其官方鏈接:
因為我們本地已經安裝過Mindspore的舊版本,因此還需要在安裝指令之后加上--upgrade操作,否則會導致系統誤以為本地已經安裝成功,不會執行安裝的操作:
$ python3 -m pip install mindspore-cuda11-dev -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade
Vmap功能測驗
這里我們先來看一個比較簡單的示例:
In [1]: from mindspore import Tensor
In [2]: from mindspore.ops.functional import vmap
In [3]: y = lambda a,b: a+b
In [4]: A = Tensor([1,2,3])
In [5]: B = Tensor([3,4,5])
In [6]: vmap_y = vmap(y,in_axes=(0,0))
In [7]: y(A[0],B[0]) # 元素加和
Out[7]: Tensor(shape=[], dtype=Int64, value= https://www.cnblogs.com/dechinphy/p/4)
In [8]: vmap_y(A,B) # 矢量加和
Out[8]: Tensor(shape=[3], dtype=Int64, value= [4, 6, 8])
在上面的這個示例中,我們定義了一個加法函式y,作用就是把輸入的兩個物件相加,這里需要注意的是,如果輸入給y的是兩個Mindspore的Tensor物件,那么會直接回傳兩個Tensor對應位置相加的結果,但是如果輸入給y的是兩個普通python的list,則輸出的結果會是兩個list的拼接,這跟不同型別的加法的實作方式有關,在文末總結中會進行解釋,這里我們只是想說明:y本身是一個元素加和的函式,可以通過vmap使其稱為矢量加和的函式,關于輸入的in_axes引數,指的是擴展的維度,比如我們寫了一個支持\((A,A)\times(A,1)\)維度的函式,如果把in_axes引數設定為0,那么就可以得到一個支持計算\((B,A,A)\times(B,A,1)\)維度的函式,其中in_axes引數,決定的是被擴展的維度B所在的位置,這一點我們可以看一下vmap的官方示例:
在這個案例中,也是定義了一個普通的加和函式,通過vmap去擴展不同的維度,大致的計算邏輯為:
\[(A,)+(A,)+(A,)\\ \Downarrow^{in\_axes=(0,1,None)}\\ (B,A)+(A,B)+(A,)=(B,A)+(B,A)+(1,A)=(B,A)\\ \Downarrow^{out\_axes=1}\\ (A,B) \]其實這個程序中關于in_axes是比較容易可以理解的,但是這個out_axes有時候會讓人難以捉摸,在github上專門有人提出了這個issue并有人做出了解釋:
結合上面的案例,其實out_axes就是決定了擴展的維度B在結果中的位置,比如out_axes=1,所對應的結果中就是\((x,B,x,...x)\),也就是說,其不影響計算的結果,但是有可能會對計算結果進行轉置操作,在MindSpore和Numpy中稱為swap_axes,
總結概要
本文介紹了華為推出的深度學習框架MindSpore中最新支持的vmap功能函式,可以用于向量化的計算,本質上的主要作用是替代并加速python中的for回圈的操作,最早是在numba和pytroch、jax中對vmap功能進行了支持,其實numpy中的底層計算也用到了向量化的運算,因此速度才如此之快,vmap在python中更多的是與即時編譯功能jit一同使用,能夠起到簡化編程的同時對性能進行極大程度的優化,尤其是python中的for回圈的優化,但是對于一些numpy、jax或者MindSpore中已有的算子而言,還是建議直接使用其已經實作的算子,而不是vmap再手寫一個,
著作權宣告
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/ms-vmap.html
作者ID:DechinPhy
更多原著文章請參考:https://www.cnblogs.com/dechinphy/
打賞專用鏈接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
騰訊云專欄同步:https://cloud.tencent.com/developer/column/91958
參考鏈接
- https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/functional.py#L845
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/451294.html
標籤:其他
下一篇:NLP 自然語言處理實戰
