三種非線性激活函式sigmoid、tanh、ReLU,
sigmoid: y = 1/(1 + e-x)
tanh: y = (ex - e-x)/(ex + e-x)
ReLU:y = max(0, x)

在隱藏層,tanh函式要優于sigmoid函式,可以看作是sigmoid的平移版本,優勢在于其取值為 [-1, 1],資料的平均值為0,而sigmoid的平均值為0.5,有類似資料中心化的效果,
但在輸出層,sigmoid可能會優于tanh,原因在于我們希望輸出結果的概率落在0~1之間,比如二元分類問題,sigmoid可以作為輸出層的激活函式,
在實際情況中,特別是在訓練深層網路時,sigmoid和tanh會在端值趨近飽和,造成訓練速度減慢,故深層網路的激活函式多是采用ReLU,淺層網路可以采用sigmoid和tanh函式,
為弄清在反向傳播中如何進行梯度下降,來看一下三個函式的求導程序:
1. sigmoid求導
sigmoid函式定義為 y = 1/(1 + e-x) = (1 + e-x)-1
相關的求導公式:(xn)' = n * xn-1 和 (ex)' = ex
應用鏈式法則,其求導程序為:
dy/dx = -1 * (1 + e-x)-2 * e-x * (-1)
= e-x * (1 + e-x)-2
= (1 + e-x - 1) / (1 + e-x)2
= (1 + e-x)-1 - (1 + e-x)-2
= y - y2
= y(1 -y)
2. tanh求導
tanh函式定義為 y = (ex - e-x)/(ex + e-x)
相關的求導公式:(u/v)' = (u' v - uv') / v2
應用鏈式法則,其求導程序為:
dy/dx = ( (ex - e-x)' * (ex + e-x) - (ex - e-x) * (ex + e-x)' ) / (ex + e-x)2
= ( (ex - (-1) * e-x) * (ex + e-x) - (ex - e-x) * (ex + (-1) * e-x) ) / (ex + e-x)2
= ( (ex + e-x)2 - (ex - e-x)2 ) / (ex + e-x)2
= 1 - ( (ex - e-x)/(ex + e-x) )2
= 1 - y2
3. ReLU求導
ReLU函式定義為 y = max(0, x)
簡單地推導得 當x <0 時,dy/dx = 0; 當 x >= 0時,dy/dx = 1
接下來著重討論下ReLU
在深度神經網路中,通常選擇線性整流函式(ReLU,Rectified Linear Units)作為神經元的激活函式,ReLU源于對動物神經科學的研究,2001年,Dayan 和 Abbott 從生物學角度模擬出了腦神經元接受信號更精確的激活模型,如圖:

其中橫軸是刺激電流,縱軸是神經元的放電速率,同年,Attwell等神經學科學家通過研究大腦的能量消耗程序,推測神經元的作業方式具有稀疏性和分布性;2003年,Lennie等神經學科學家估測大腦同時被激活的神經元只有1~4%,這進一步表明了神經元作業的稀疏性,
那么,ReLU是如何模擬神經元作業的呢

從上圖可以看出,ReLU其實是分段線性函式,把所有的負值都變為0,正值不變,這種性質被稱為單側抑制,因為單側抑制的存在,才使得神經網路中的神經元也具有了稀疏激活性,尤其在深度神經網路中(如CNN),當模型增加N層之后,理論上ReLU神經元的激活率將降低2的N次方倍,或許有人會問,ReLU的函式影像為什么非得長成這樣子,其實不一定這個樣子,只要能起到單側抑制的作用,無論是鏡面翻轉還是180°翻轉,最終神經元的輸入也只是相當于加上了一個常數項系數,并不會影響模型的訓練結果,之所以這樣定義,或許是為了符合生物學角度,便于我們理解吧,
這種稀疏性有什么作用呢?因為我們的大腦作業時,總有一部分神經元處于活躍或抑制狀態,與之類似,當訓練一個深度分類模型時,和目標相關的特征往往也就幾個,因此通過ReLU實作稀疏后的模型能夠更好地挖掘相關特征,使網路擬合訓練資料,
相比其他激活函式,ReLU有幾個優勢:(1)比起線性函式來說,ReLU的表達能力更強,尤其體現在深度網路模型中;(2)較于非線性函式,ReLU由于非負區間的梯度為常數,因此不存在梯度消失問題(Vanishing Gradient Problem),使得模型的收斂速度維持在一個穩定狀態,(注)梯度消失問題:當梯度小于1時,預測值與真實值之間的誤差每傳播一層就會衰減一次,如果在深層模型中使用sigmoid作為激活函式,這種現象尤為明顯,將導致模型收斂停滯不前,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/36673.html
標籤:其他
上一篇:ImportError: cannot import name '_path' from 'matplotlib' ,如何解決?
