僅當監控值大于閾值時,如何激活 keras.EarlyStopping。例如,如何earlystop = EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5, verbose=1, mode='auto')僅在 val 精度 > 0.9 時觸發?另外,我應該如何正確匯出中間模型,例如每 50 個時期?
我沒有太多的知識,并且 EarlyStopping 的基線引數似乎意味著閾值以外的其他東西。
uj5u.com熱心網友回復:
停止指標閾值的最佳方法是使用 Keras 自定義回呼。下面是一個自定義回呼的代碼(SOMT - 在度量閾值上停止),它將完成這項作業。SOMT 回呼對于根據訓練準確度或驗證準確度或兩者的值結束訓練很有用。使用形式為 callbacks=[SOMT(model, train_thold, valid_thold)] 其中
- model 是您編譯的模型的名稱
- train_thold 是一個浮點數。這是模型必須達到的準確度值(以百分比為單位)才能有條件地停止訓練
- valid_threshold 是一個浮點數。為了有條件地停止訓練,模型必須達到驗證準確率的值(以百分比為單位)
注意要停止訓練,必須在同一時期超過 train_thold 和 valid_thold。
如果您想僅根據訓練精度停止訓練,請將 valid_thold 設定為 0.0。
同樣,如果您想停止僅驗證準確度集 train_thold= 0.0 的訓練。
注意,如果兩個閾值都沒有在同一個 epoch 中達到,訓練將持續到 epochs 的值。如果在同一時期達到兩個閾值,則停止訓練并將您的模型權重設定為該時期的權重。
例如,假設您想在
訓練準確度達到或超過 95% 并且驗證準確度達到至少 85%時停止訓練,
那么代碼將是 callbacks=[SOMT(my_model, .95, .85 )]
# the callback uses the time module so
import time
class SOMT(keras.callbacks.Callback):
def __init__(self, model, train_thold, valid_thold):
super(SOMT, self).__init__()
self.model=model
self.train_thold=train_thold
self.valid_thold=valid_thold
def on_train_begin(self, logs=None):
print('Starting Training - training will halt if training accuracy achieves or exceeds ', self.train_thold)
print ('and validation accuracy meets or exceeds ', self.valid_thold)
msg='{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format('Epoch', 'Train Acc', 'Train Loss','Valid Acc','Valid_Loss','Duration')
print (msg)
def on_train_batch_end(self, batch, logs=None):
acc=logs.get('accuracy')* 100 # get training accuracy
loss=logs.get('loss')
msg='{0:1s}processed batch {1:4s} training accuracy= {2:8.3f} loss: {3:8.5f}'.format(' ', str(batch), acc, loss)
print(msg, '\r', end='') # prints over on the same line to show running batch count
def on_epoch_begin(self,epoch, logs=None):
self.now= time.time()
def on_epoch_end(self,epoch, logs=None):
later=time.time()
duration=later-self.now
tacc=logs.get('accuracy')
vacc=logs.get('val_accuracy')
tr_loss=logs.get('loss')
v_loss=logs.get('val_loss')
ep=epoch 1
print(f'{ep:^8.0f} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
if tacc>= self.train_thold and vacc>= self.valid_thold:
print( f'\ntraining accuracy and validation accuracy reached the thresholds on epoch {epoch 1}' )
self.model.stop_training = True # stop training
Note include this code after compiling your model and prior to fitting your model
train_thold= .98
valid_thold=.95
callbacks=[SOMT(model, train_thold, valid_thold)]
# training will halt if train accuracy meets or exceeds train_thold
# AND validation accuracy meets or exceeds valid_thold in the SAME epoch
In model.fit include callbacks=callbacks, verbose=0. At the end of each epoch the callback produces a spreadsheet like printout of the form
Epoch Train Acc Train Loss Valid Acc Valid_Loss Duration
1 0.90 4.3578 0.95 2.3982 84.16
2 0.95 1.6816 0.96 1.1039 63.13
3 0.97 0.7794 0.95 0.5765 63.40
training accuracy and validation accuracy reached the thresholds on epoch 3.
轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/456247.html
