我訓練了以下基於 GRU 的模型。請注意,我在建立 GRU 層時傳入了引數 stateful=True。
class LearningToSurpriseModel(tf.keras.Model):
    """Sequence model built from an Embedding, a stateful GRU and a Dense layer.

    The Dense layer has ``vocab_size`` units, so the model emits one score per
    vocabulary entry for every time step returned by the GRU.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Create the three layers.

        Args:
            vocab_size: Number of entries in the embedding table and number of
                units in the output Dense layer.
            embedding_dim: Size of each embedding vector.
            rnn_units: Number of GRU units.
        """
        # ``super()`` already binds the instance; passing ``self`` again would
        # hand the base initializer a stray positional argument.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units,
            stateful=True,
            return_sequences=True,
            return_state=True,
            reset_after=True,
        )
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Run embedding -> GRU -> dense on ``inputs``.

        Args:
            inputs: Input passed to the embedding layer
                (presumably integer token ids -- confirm against the dataset).
            states: Optional initial state for the GRU. When ``None``, the
                GRU's ``get_initial_state`` is used.
            return_state: When true, also return the GRU state.
            training: Forwarded to every layer.

        Returns:
            The Dense output, or ``(output, states)`` when ``return_state``
            is true.
        """
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        return x

    @tf.function
    def train_step(self, inputs):
        # [defining here my training step] -- body elided in the original post.
        ...
接著我實例化我的模型:
# Build the model; the vocabulary size comes from the character lookup layer.
# Arguments are positional, in the order (vocab_size, embedding_dim, rnn_units).
model = LearningToSurpriseModel(
    len(ids_from_chars.get_vocabulary()), embedding_dim, rnn_units
)
[編譯模型並做其他準備] 下面的自訂回呼(callback)會在每個 epoch 結束時手動重置狀態。
# Refer to the GRU by its attribute name instead of its position in
# ``model.layers``, so adding or reordering layers cannot pick the wrong one.
gru_layer = model.gru
class CustomCallback(tf.keras.callbacks.Callback):
    """Callback that resets a recurrent layer's state after every epoch."""

    def __init__(self, gru_layer):
        """Remember the layer whose state should be reset.

        Args:
            gru_layer: Layer exposing ``reset_states()``.
        """
        # Let the base Callback set up its own attributes first.
        super().__init__()
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        """Reset the stored layer's state; ``epoch`` and ``logs`` are unused."""
        self.gru_layer.reset_states()
# Train with the early-stopping callback plus the per-epoch state reset.
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=[EarlyS, CustomCallback(gru_layer)],
    verbose=1,
)
這樣狀態會被重置為零。我想依照 https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html 中的想法,將狀態(重新)初始化為隨機噪聲。要產生這種隨機噪聲,有什麼好的實作方式?
我應該覆寫 reset_states() 並加上 states 引數嗎?
uj5u.com熱心網友回復:
您可以嘗試使用tf.random.normal:
# Draw one normally distributed state vector per batch element, then install it.
initial_noise = tf.random.normal((batch_size, rnn_units))
self.gru_layer.reset_states(initial_noise)
或者:
# Same idea with uniformly distributed samples instead of normal ones.
initial_noise = tf.random.uniform((batch_size, rnn_units))
self.gru_layer.reset_states(initial_noise)
所以,你的 Callback 可能看起來像這樣:
import tensorflow as tf
class CustomCallback(tf.keras.callbacks.Callback):
    """Callback that re-initialises a recurrent layer's state with random noise.

    After every epoch the layer's state is replaced by samples from a normal
    distribution of shape ``(batch_size, dims)``.
    """

    def __init__(self, gru_layer, batch_size, dims, stddev=1.0):
        """Store the target layer and the shape of the noise to draw.

        Args:
            gru_layer: Layer exposing ``reset_states(states)``.
            batch_size: First dimension of the generated state.
            dims: Second dimension of the generated state
                (presumably the layer's unit count -- confirm at the call site).
            stddev: Standard deviation of the noise. The default of 1.0
                matches ``tf.random.normal``'s own default.
        """
        # Let the base Callback set up its own attributes first.
        super().__init__()
        self.gru_layer = gru_layer
        self.batch_size = batch_size
        self.dims = dims
        self.stddev = stddev

    def on_epoch_end(self, epoch, logs=None):
        """Replace the layer state with fresh noise; ``epoch``/``logs`` unused."""
        noise = tf.random.normal(
            (self.batch_size, self.dims), stddev=self.stddev
        )
        self.gru_layer.reset_states(noise)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/432310.html
