大部分使用 tensorflow 的同學會使用 fit() 或者 fit_generator() 方法訓練模型, 這兩個 api 對于剛接觸深度學習的同學非常友好和方便,但是由于其是非常深度的封裝,對于希望自定義訓練程序的同學就顯得不是那么方便,而且,對于 GAN 這種需要分步進行訓練的模型,也無法直接使用 fit 或者 fit_generator 直接訓練的,因此,tensorflow 提供了 train_on_batch 這個 api,對一個 mini-batch 的資料進行梯度更新,
總結優點如下:
- 更精細自定義訓練程序,更精準的收集 loss 和 metrics
- 分步訓練模型-GAN的實作
- 多GPU訓練保存模型更加方便
- 更多樣的資料加載方式
函式原型:
y_pred = Model.train_on_batch(
x,
y=None,
sample_weight=None,
class_weight=None,
reset_metrics=True,
return_dict=False,
)
官方檔案:train_on_batch
引數詳解:
x:模型輸入,單輸入就是一個 numpy 陣列, 多輸入就是 numpy 陣列的串列y:標簽,單輸出模型就是一個 numpy 陣列, 多輸出模型就是 numpy 陣列串列sample_weight:mini-batch 中每個樣本對應的權重,形狀為 (batch_size)class_weight:類別權重,作用于損失函式,為各個類別的損失添加權重,主要用于類別不平衡的情況, 形狀為 (num_classes)reset_metrics:默認True,回傳的metrics只針對這個mini-batch, 如果False,metrics 會跨批次累積return_dict:默認 False, y_pred 為一個串列,如果 True 則 y_pred 是一個字典
實體:
- 單輸出模型,且只有loss,沒有metrics, 此時 y_pred 為一個標量,代表這個 mini-batch 的 loss, 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'])
history = model.train_on_batch(x=image,y=label)
# history 為標量
- 單輸出模型,既有loss,也有metrics, 此時 y_pred 為一個串列,代表這個 mini-batch 的 loss 和 metrics, 串列長度為 1+len(metrics), 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'], metrics=['accuracy'])
history = model.train_on_batch(x=image,y=label) # len(history ) == 2
# history 為長度為2的串列,
# history [0]為loss,
# history [1]為accuracy
- 多輸出模型,既有loss,也有metrics, 此時 y_pred 為一個串列,串列長度為 1+len(loss)+len(metrics), 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam, loss=['binary_crossentropy', 'binary_crossentropy'],
metrics=['accuracy', 'accuracy'])
history = model.train_on_batch(x=image,y=label) # len(history ) == 5
# history [0]為總loss(按照loss_weights加權),
# history [1]為第一個輸出的loss,
# history [2]為第二個輸出的loss
# history [3]為第一個accuracy,
# history [4]為第二個accuracy
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/292625.html
標籤:AI
