本篇博客將介紹神經網路訓練程序中的三個必備技能:使用預訓練權重、凍結訓練和斷點恢復,巧妙運用這三個技巧可以很有效地提高網路的訓練效率和效果,
文章目錄
- 一、引言
- 二、使用預訓練權重
- 三、凍結訓練
- 四、斷點恢復
- 五、預訓練和微調
一、引言
If I have seen further, it is by standing on the shoulders of giants.
遷移學習在計算機視覺領域中是一種很流行的方法,因為它可以建立精確的模型,耗時更短,利用遷移學習,不是從零開始學習,而是從之前解決各種問題時學到的模式開始,這樣,我們就可以利用以前的學習成果,避免從零開始,
二、使用預訓練權重
在計算機視覺領域中,遷移學習通常是通過使用預訓練模型來表示的,預訓練模型是在大型基準資料集上訓練的模型,用于解決相似的問題,由于訓練這種模型的計算成本較高,因此,匯入已發布的成果并使用相應的模型是比較常見的做法,例如,在目標檢測任務中,首先要利用主干神經網路進行特征提取,這里使用的backbone一般就是VGG、ResNet等神經網路,因此在訓練一個目標檢測模型時,可以使用這些神經網路的預訓練權重來將backbone的引數初始化,這樣在一開始就能提取到比較有效的特征,
可能大家會有疑問,預訓練權重是針對他們資料集訓練得到的,如果是訓練自己的資料集還能用嗎?預訓練權重對于不同的資料集是通用的,因為特征是通用的,一般來講,從0開始訓練效果會很差,因為權值太過隨機,特征提取效果不明顯,對于目標檢測模型來說,一般不從0開始訓練,至少會使用主干部分的權值,雖然有些論文提到了可以不用預訓練,但這主要是因為他們的資料集比較大而且他們的調參能力很強,如果從0開始訓練,網路在前幾個epoch的Loss可能會非常大,并且多次訓練得到的訓練結果可能相差很大,因為權重初始化太過隨機,
PyTorch提供了state_dict()和load_state_dict()兩個方法用來保存和加載模型引數,前者將模型引數保存為字典形式,后者將字典形式的模型引數載入到模型當中,下面是使用預訓練權重(加載預訓練模型)的代碼,其中model_path就是預訓練權重檔案的路徑:
# 第一步:讀取當前模型引數
model_dict = model.state_dict()
# 第二步:讀取預訓練模型
pretrained_dict = torch.load(model_path, map_location = device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
# 第三步:使用預訓練的模型更新當前模型引數
model_dict.update(pretrained_dict)
# 第四步:加載模型引數
model.load_state_dict(model_dict)
但是,使用load_state_dict()加載模型引數時,要求保存的模型引數鍵值型別和模型完全一致,一旦我們對模型結構做了些許修改,就會出現類似unexpected key module.xxx.weight問題,比如在目標檢測模型中,如果修改了主干特征提取網路,只要不是直接替換為現有的其它神經網路,基本上預訓練權重是不能用的,要么就自己判斷權值里卷積核的shape然后去匹配,要么就只能利用這個主干網路在諸如ImageNet這樣的資料集上訓練一個自己的預訓練模型;如果修改的是后面的neck或者是head的話,前面的backbone的預訓練權重還是可以用的,下面是權值匹配的示例代碼,把不匹配的直接pass了:
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
temp = {}
for k, v in pretrained_dict.items():
try:
if np.shape(model_dict[k]) == np.shape(v):
temp[k]=v
except:
pass
model_dict.update(temp)
model.load_state_dict(model_dict)
三、凍結訓練
凍結訓練其實也是遷移學習的思想,在目標檢測任務中用得十分廣泛,因為目標檢測模型里,主干特征提取部分所提取到的特征是通用的,把backbone凍結起來訓練可以加快訓練效率,也可以防止權值被破壞,在凍結階段,模型的主干被凍結了,特征提取網路不發生改變,占用的顯存較小,僅對網路進行微調,在解凍階段,模型的主干不被凍結了,特征提取網路會發生改變,占用的顯存較大,網路所有的引數都會發生改變,舉個例子,如果在解凍階段設定batch_size為4,那么在凍結階段有可能可以把batch_size設定到8,下面是進行凍結訓練的示例代碼,假設前50個epoch凍結,后50個epoch解凍:
# 凍結階段訓練引數,learning_rate和batch_size可以設定大一點
Init_Epoch = 0
Freeze_Epoch = 50
Freeze_batch_size = 8
Freeze_lr = 1e-3
# 解凍階段訓練引數,learning_rate和batch_size設定小一點
UnFreeze_Epoch = 100
Unfreeze_batch_size = 4
Unfreeze_lr = 1e-4
# 可以加一個變數控制是否進行凍結訓練
Freeze_Train = True
# 凍結一部分進行訓練
batch_size = Freeze_batch_size
lr = Freeze_lr
start_epoch = Init_Epoch
end_epoch = Freeze_Epoch
if Freeze_Train:
for param in model.backbone.parameters():
param.requires_grad = False
# 解凍后訓練
batch_size = Unfreeze_batch_size
lr = Unfreeze_lr
start_epoch = Freeze_Epoch
end_epoch = UnFreeze_Epoch
if Freeze_Train:
for param in model.backbone.parameters():
param.requires_grad = True
如果不進行凍結訓練,一定要注意引數設定,注意上述代碼中凍結階段和解凍階段的learning_rate和batch_size是不一樣的,另外起始epoch和結束epoch也要重新調整一下,如果是從0開始訓練模型(不使用預訓練權重),那么一定不能進行凍結訓練,
四、斷點恢復
在上面凍結訓練和解凍訓練的代碼里設定了不同的batch_size,前者是8后者是4,有可能凍結訓練的時候顯存是夠用的,結果解凍后顯存不足了,這個時候需要重新把解凍訓練階段的batch_size調得更小一點,但是網路才訓練了凍結階段的50個epoch,backbone引數還是用的預訓練權重呢,網路效果肯定不夠好,難道要前功盡棄重新開始訓練?這時候就要使用斷點恢復技術了,其實斷點恢復的思想很簡單,就是把網路初始設定的model_path改為出錯前保存好的權值檔案,然后調整一下起始epoch和終止epoch即可,比如在前面提到的這種情況里,在第51個epoch報了錯,那么可以把model_path修改為第50個epoch訓練結束后保存的權值檔案,然后把起始epoch調整成50就可以了,
斷點恢復的應用范圍非常非常廣,最常見的情況就是代碼跑到一半因為某些原因中斷了(比如電腦突然死機重啟這種不可抗力因素),又不想從頭重新跑,那么就可以利用斷點恢復訓練的方法,這樣可以節省不少時間,再比如,一個非常常見的情況,假如一開始設定了100個epoch,結果模型訓練結束時,Loss還呈現下降的趨勢,也就是模型還沒有收斂,這種現象有可能就是epoch設定小了,所以可以把第100個epoch訓練得到的權值檔案當做初始權值檔案再訓練幾個epoch看看,避免重新設定epoch從頭訓練,
當然,想要執行斷點恢復首先需要把每個epoch得到的權值檔案保存起來,這樣才能修改model_path重新加載,斷點恢復和常規的模型保存加載的區別其實就是epoch也要修改一下而已,保存權重可以用以下方法:
torch.save(model.state_dict(), "你要保存到的路徑")
五、預訓練和微調
最后再來總結一下預訓練和微調,這是兩個非常重要的概念,其實也很好理解,舉個栗子是最能直觀理解的,
假如我們現在要搭建一個網路模型來完成一個影像分類的任務,首先我們需要把網路的引數進行初始化,然后在訓練網路的程序中不斷對引數進行調整,直到網路的損失越來越小,在訓練程序中,一開始初始化的引數會不斷變化,如果結果已經滿意了,那我們就可以把訓練好的模型引數保存下來,以便訓練好的模型可以在下次執行類似任務的時候獲得比較好的效果,這個程序就是預訓練(Pre-Training),
假如在完成上面的模型訓練后,我們又接到另一個類似的影像分類任務,這時我們就可以直接使用之前保存下來的模型引數作為這一次任務的初始化引數,然后在訓練程序中依據結果不斷進行修改,這個程序就是微調(Fine-Tuning),
我們使用的神經網路越深,就需要越多的樣本來進行訓練,否則就很容易出現過擬合現象,比如我們想訓練一個識別貓的模型,但是自己標注資料精力有限只標了100張,這時就可以考慮ImageNet資料集,可以在ImageNet上訓練一個模型,然后使用該模型作為類似任務的初始化或特征提取器,這樣既節省了時間和計算資源,又能很快地達到較好的效果,當然,采用預訓練+微調也不是絕對有效的,上面識別貓的例子可以這樣做是因為ImageNet里有貓的影像,所以可以認為是一個類似的資料集,如果是識別癌細胞的話,效果可能就不是那么好了,關于預訓練和微調是有很多策略的,經驗也很重要,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/385484.html
標籤:AI
