訓練網路之前有很多引數要設定,不了解各個引數的含義就沒法合理地設定引數值,訓練效果也會因此大受影響,本篇博客記錄一下網路訓練里的Batch Size、Iterations和Epochs怎么理解,也方便后續理解模型微調和凍結訓練等技巧,
文章目錄
- 一、引言
- 二、Batch Size
- 三、Iterations
- 四、Epochs
- 五、舉個栗子
一、引言
首先要了解一下為什么會出現Batch Size這個概念,深度學習演算法是迭代的,也就是會多次使用演算法獲取結果,以得到最優化的結果,每次迭代更新網路引數有兩種方式,也是兩種極端:
第一種是Batch Gradient Descent,批梯度下降,即把所有資料一次性輸入進網路,把資料集里的所有樣本都看一遍,然后計算一次損失函式并更新引數,這種方式計算量開銷很大,速度也很慢,不支持在線學習,
第二種是Stochastic Gradient Descent,隨機梯度下降,即把每次只把一個資料輸入進網路,每看一個資料就算一下損失函式并更新引數,這種方式雖然速度比較快,但是收斂性能不好,可能會在最優點附近震蕩,兩次引數的更新也有可能互相抵消掉,
可見,這兩種方式都有問題,所以現在一般都是采用兩種方式的折衷,Mini-Batch Gradient Decent,小批梯度下降,就是把資料進行切片,劃分為若干個批,按批來更新引數,這樣,一個批中的一組資料共同決定了本次梯度的方向,下降起來就不容易跑偏,減少了隨機性,并且由于批的樣本數與整個資料集相比小了很多,計算量也不是很大,
二、Batch Size
所謂的batch_size,就是每次訓練所選取的樣本數,通俗點講就是一個 batch中的樣本總數,一次喂進網路的樣本數,batch_size的選擇會影響梯度下降的方向,
在合理范圍內增大batch_size有以下幾個好處:
- 記憶體利用率高,大矩陣乘法的并行化效率提高;
- 跑完全部資料所需的迭代次數少,對于相同資料量的處理速度可以進一步加快;
- 在一定范圍內,一般來說batch_size越大,其確定的下降方向越準,引起的訓練震蕩越小,
但也不能盲目增大,否則會有以下幾個壞處:
- 記憶體容量可能撐不住,報錯RuntimeError:CUDA out of memory;
- 跑完全部資料集所需的迭代次數減少,要想達到相同的精度,其所花費的時間大大增加了,從而對引數的修正也就顯得更加緩慢;
- 當batch_size增大到一定程度時,其確定的下降方向已經基本不再變化了,
三、Iterations
所謂的iterations,就是訓練完全部資料需要迭代的次數,通俗點講一個iteration就是使用batch_size個樣本把網路訓練一次,iterations就是整個資料集被劃分成的批次數目,數值上等于data_size/batch_size,
把全部的樣本資料,按照batch_size進行切片,劃分成iterations塊,每個iteration結束后都會更新一次網路結構的引數,每一次迭代得到的結果都會被作為下一次迭代的初始值,
一個iteration=一個batch_size的資料進行一次forward propagation和一次backward propagation,
四、Epochs
所謂的epochs,就是前向傳播和反向傳播程序中所有批次的訓練迭代次數,一個epoch就是整個資料集的一次前向傳播和反向傳播,通俗點講,epochs指的就是訓練程序中全部資料將被送入網路訓練多少次,
為什么要使用多個epoch進行訓練呢?因為在神經網路中傳遞完整的資料集一次是不夠的,我們需要將完整的資料集在同樣的神經網路中傳遞多次,我們使用的是有限的資料集,僅僅更新權重一次或者說使用一個epoch是不夠的,
如果epochs太小,網路有可能發生欠擬合;如果epochs太大,則有可能發生過擬合,具體怎么選擇要根據實驗結果去判斷和選擇,對于不同資料集選取的epochs是不一樣的,
五、舉個栗子
假設有1024個訓練樣本,batch_size=8,epochs=10,那么:每個epoch會訓練1024/8=128個iteration,全部1024個訓練樣本會被這樣訓練10次,所以一共會有1280個iteration,發生1280次前向傳播和反向傳播,注意,由于Batch Normalization層的存在,batch_size一般設定為2的倍數,并且不能為1,
總結一下:
1. Batch使用訓練集中的一小部分樣本對模型權重進行一次反向傳播的引數更新,這一小部分樣本被稱為“一批資料”;
2. Iteration是使用一個Batch資料對模型進行一次引數更新的程序,被稱為“一次訓練”;
3. Epoch使用訓練集的全部資料對模型進行一次完整訓練,被稱為“一代訓練”,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/385485.html
標籤:AI
