訓練流程
-
從
tools/train.py開始:- 一通讀取 cfg ,初步設定一些基本引數,log 引數;
- build 模型,build 資料集 (有多少個 workflow 就 build 多少個資料集,比如如果 train 的程序中還進行 val 則表示有 2 個 workflow) ;
- 最后呼叫
mmdet.apis.train_detector,傳入剛才 build 好的 model,datasets,配置引數等,
-
進入
mmdet.apis.train_detector:- 為每一個 workflow 對應的 dataset , build data_loader ( data_loader 繼承自 pytorch 自帶的 DataLoader 類,這里先簡單理解,其將 dataset 里面 data sample 包裝成 data batch ,作為生成器的形式,每次用 for 迭代 load batch ) ;
- 判斷是否是分布式訓練,分布式訓練則用
MMDistributedDataParallel封裝 model,單 GPU 訓練則MMDataParallel; - build optimizer;
- 重頭戲: runner ,runner 可以理解為操控整個訓練程序的核心,首先,先跳過中間那一堆對 runner hook 的設定,直接看到最后,呼叫了
runner.run(),訓練從此處開始,
-
runner 是
EpochBasedRunner類的實體,進入EpochBasedRunner類的定義,可以看到最主要的是 run 方法:def run(self, data_loaders, workflow, max_epochs, **kwargs): #... while self.epoch < max_epochs: for i, flow in enumerate(workflow): mode, epochs = flow if isinstance(mode, str): # self.train() if not hasattr(self, mode): raise ValueError( f'runner has no method named "{mode}" to run an ' 'epoch') epoch_runner = getattr(self, mode) else: raise TypeError( 'mode in workflow must be a str, but got {}'.format( type(mode))) for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break epoch_runner(data_loaders[i], **kwargs)workflow變數的注釋:workflow (list[tuple]): A list of (phase, epochs) to specify the running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,最后 4 行是重點,根據每個 workflow 的 mode 和 epochs 呼叫 epochs 次相應的函式,比如:
for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break #when mode == 'train' # `epoch_runner(data_loaders[i], **kwargs)` == self.train(data_loaders[i], **kwargs)一個 epoch 相當于遍歷一遍資料集的所有資料,
接下來看看 EpochBasedRunner.train() :
- 設定基本引數
- 在一些關鍵節點的前后呼叫了 hook :
before_train_epoch,before_train_iter,after_train_iter,after_train_epoch,執行反向傳播是在after_train_iter處, (先不糾結 hook 是個啥) - data_loader 為生成器,用 for 迭代取出 1 個 batch 的資料,進入逐個 iter 的訓練:
- 如果有為該 Runner 指定 batch processor,則呼叫,
- 否則,直接呼叫模型的 train_step,傳入訓練資料,
hook
hook 的作用是對一些中間結果做相應的操作,比如列印 log ,比如在 training 程序中的 evaluation 等等,
下面決議一下組態檔中出現的 TensorboardLoggerHook
先從 EpochBasedRunner 如何使用 hook 看起:
-
EpochBasedRunner.register_hook()- 注冊 hook 到 runner,根據 hook cfg build 相應的 hook 實體,放到 runner 的 hook 佇列中,hook 佇列是一個優先級佇列,優先級可以在傳入 hook 的時候指定,
-
EpochBasedRunner.call_hook(fn_name)- 使用 hook ,根據需要呼叫的函式名
fn_name,呼叫每個 hook 里的同名函式,因為 runner 快取著中間結果,需要將 runner 作為引數傳進去,
- 使用 hook ,根據需要呼叫的函式名
TensorboardLoggerHook
-
該類主要的作用是將每次 iter 或 epoch 完記錄訓練結果到 tensorboard (即寫到 summary 檔案里)
-
TensorboardLoggerHook.after_train_iter(runner)該函式做了什么?判斷是否達到 interval,比如在組態檔中指定了每 50 個 iter 才 log 訓練結果,如果達到 50 個 iter,則對 50 個 iter 的結果求平均,再呼叫自己的 log 函式, 50 個 iter 的結果存放在 runner.log_buffer 里,
-
TensorboardLoggerHook.log(runner)將 runner.log_buffer 里的結果值,通過 summary_writer 寫到 summary 檔案,
除了 Logger 這種形式的 hook 之外,還有其他一些功能也以 hook 的形式實作,比如 optimizer 對應的 OptimizerHook,或者 training 程序中的 eval 也是通過 EvaluationHook 呼叫,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/79667.html
標籤:其他
上一篇:[開源框架]mmdetection3d學習(一):初步認識
下一篇:深度學習的學習率調節實踐
