0605-優化器
目錄- 一、優化器概述
- 二、針對不同的網路設定不同的 lr
- 三、針對不同的層設定不同的 lr
- 四、動態修改 lr
pytorch完整教程目錄:https://www.cnblogs.com/nickchen121/p/14662511.html
一、優化器概述
torch 把深度學習中常用的優化方法都存盤在 torch.optim 中,它的設計十分靈活,可以很方便的擴展成自定義的優化方法,
所有的優化方法都繼承基類 optim.Optimizer,并實作了自己的優化步驟,接下來我們將以最基本的優化方法——隨機梯度下降法(SGD)距離說明,在這里需要重點掌握以下三個方法:
- 優化方法的基本使用方法
- 如何對模型的不同部分設定不同的學習率(lr)
- 如何調整學習率
import torch as t
from torch import nn
from torch.autograd import Variable as V
# 首先定義一個 LeNet 網路
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU(),
nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5),
nn.ReLU(), nn.MaxPool2d(2, 2))
self.classifier = nn.Sequential(nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
nn.Linear(120, 84), nn.ReLU(),
nn.Linear(84, 10))
def forward(self, x):
x = self.features(x)
x = x.view(-1, 16 * 5 * 5)
x = self.classifier(x)
return x
net = Net()
from torch import optim
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # 梯度清零,等價于 net.zero_grad()
inp = V(t.randn(1, 3, 32, 32))
output = net(inp)
output.backward(output)
optimizer.step() # 優化引數
二、針對不同的網路設定不同的 lr
# 為不同子網路設定不同的學習率,在 finetune 中經常用到
# 如果對某個引數不指定學習率,就使用默認學習率
ptimizer = optim.SGD(
[
{
'params': net.features.parameters()
}, # 學習率為 1e-5
{
'params': net.classifier.parameters(),
'lr': 1e-2
}
],
lr=1e-5)
三、針對不同的層設定不同的 lr
# 只為兩個全連接層設定較大的學習率,其余層的學習率較小
special_layers = nn.ModuleList([net.classifier[0], net.classifier[3]])
special_layers_params = list(map(id, special_layers.parameters())) # 得到特殊層的 id
# 篩選出不屬于特殊層的層
base_params = filter(lambda p: id(p) not in special_layers_params,
net.parameters())
# 對于特殊層和非特殊層設定不同的 lr
optimizer = t.optim.SGD([{
'params': base_params
}, {
'params': special_layers.parameters(),
'lr': 0.01
}],
lr=0.001)
四、動態修改 lr
在跑代碼的程序中,我們可能需要中途改變學習率的大小,在 torch 中提供了兩種做法:
- 直接修改 optimizer.parm_groups 中對應的學習率(不推薦)
- 由于 optimizer 十分輕量級,開銷很小,因此可以新建優化器(推薦)
如果使用第二種方法新建一個優化器,在這個程序中新建的優化器會初始化動量等狀態資訊,這對使用動量的優化器來說(如自帶 momentum 的 sgd),可能會造成損失函式在收斂程序中震蕩,
# 調整學習率,新建一個 optimizer
old_lr = 0.1
optimizer = optim.SGD([{
'params': net.features.parameters()
}, {
'params': net.classifier.parameters(),
'lr': old_lr * 0.1
}],
lr=1e-5)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/280178.html
標籤:其他
