我目前正在嘗試添加一個功能來中斷和恢復基于此示例代碼創建的 GAN 的訓練:https : //machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-從頭開始/
我設法讓它以一種方式作業,我將整個復合 GAN 的權重保存在 summarise_performance 函式中,該函式每 10 個時期觸發一次,如下所示:
# save all weights
filename3 = 'weights_d.h5' % (step 1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))
它加載到我添加到程式開頭的一個名為 load_model 的函式中,該函式采用正常構建的 gan 架構,但將其權重更新為最新值,如下所示:
#load model from file and return startBatch number
def load_model(gan_model):
start_batch = 0
files = glob.glob("./weights_0*.h5")
if(len(files) > 0 ):
most_recent_file = files[len(files)-1]
gan_model.load_weights(most_recent_file)
#TODO: breaks if using more than 8 digits for batches
startBatch = int(most_recent_file[10:18])
if (start_batch != 0):
print("> found existing weights; starting at batch %d" % start_batch)
return start_batch
其中 start_batch 被傳遞給 train 函式以跳過已經完成的時期。
雖然這種減輕重量的方法確實“有效”,但我仍然認為我的方法是錯誤的,因為我發現權重資料顯然不包括 GAN 的優化器狀態,??因此訓練不會像它那樣繼續沒有被打斷。
我發現保存進度同時保存優化器狀態的方式顯然是通過保存整個模型而不僅僅是權重來完成的
在這里我遇到了一個問題,因為在 GAN 中,我不僅訓練了一個模型,而且有 3 個模型:
- 生成器模型 g_model
- 判別器模型 d_model
- 和復合 GAN 模型 gan_model
這些都是相互聯系和相互依賴的。如果我采用天真的方法并分別保存和恢復這些零件模型中的每一個,我最終會得到 3 個獨立的脫節模型而不是 GAN
有沒有一種方法可以讓我恢復訓練,就好像沒有發生中斷一樣,可以保存和恢復整個 GAN?
uj5u.com熱心網友回復:
如果您想恢復整個 GAN,可以考慮使用tf.train.Checkpoint:
### In your training loop
checkpoint_dir = '/checkpoints'
checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator
gan_model = gan_model)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if ckpt_manager.latest_checkpoint:
checkpoint.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
....
....
if (epoch 1) % 40 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for epoch {} at {}'.format(epoch 1,ckpt_save_path))
### After x number of epochs, just save your generator model for inference.
generator.save('your_model.h5')
您也可以考慮完全擺脫復合模型。這是我的意思的一個例子。
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/337541.html
上一篇:Keras/Conv2D:奇怪了,我用padding=SAME,但是尺寸還是縮小了
下一篇:構建用于引數預測的自動編碼器網路
