我正在做一個旨在預測下一個角色的專案。預測的字符 Y 被映射到 One Hot 編碼。
| 資料 | 形狀 |
|---|---|
| 一 | (N,L,60) |
| 目標資料 | (N,60) |
這是我的代碼:
HSIZE = 128
model = Sequential()
model.add(SimpleRNN(HSIZE, return_sequences=False, input_shape=(SEQLEN, nb_chars), unroll=True, use_bias=True))
這是列印的引數數量model.summary():

我選擇了一個隱藏大小h = 128,我想知道如何計算引數的數量。我嘗試手動完成,但找不到model.summary().
我知道有三個矩陣U(用于輸入)、W(用于隱藏狀態)和V(用于輸出), 偏差,我最終得到:
Dim(U) Dim(W) Dim(V) = (128,60) (128,128) (60,128) (bias)60 128 128 = 32060。
有什么想法或者你有沒有發現我身邊的任何潛在誤解?
uj5u.com熱心網友回復:
所述SimpleRNN層(同樣適用于LSTM和GRU層,并且還使用RNN沿與相應的Cell類)不不包括輸出轉換。您實際上可以通過摘要列出 128 個單位(狀態大小)的輸出形狀來猜測它。它只計算狀態序列。
因此引數的數量很簡單128*128 60*128 128 = 24192(隱藏到隱藏矩陣,輸入到隱藏矩陣,偏差)。
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/391930.html
