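Now that we have a model and data, it's time to train, validate and test our model by optimizing its parameters on our data. Training a model is an iterative process. In each iteration (called an epoch), the model:
- makes a guess about the output;
- calculates the error in its guess (the loss);
- collects the derivatives of the error with respect to its parameters (as described in the previous section);
- optimizes these parameters using gradient descent.

For a more detailed walkthrough of this process, see backpropagation from 3Blue1Brown.
Prerequisite Code
This code is carried over from the Datasets & DataLoaders and Build Model sections: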
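The four steps of one training iteration can be sketched without PyTorch on a one-parameter model. This is a minimal illustration under assumed toy data (points on y = 2x) and an assumed learning rate of 0.01. The derivative is written out by hand here; PyTorch's autograd computes it for you.

```python
# Toy data drawn from y = 2 * x; the "model" is a single weight w (illustrative values)
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

w = 0.0               # untrained parameter
learning_rate = 0.01  # assumed value for this sketch

for epoch in range(200):
    # 1. Make a guess (forward pass)
    preds = [w * x for x in xs]
    # 2. Calculate the error of the guess: mean squared error
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 3. Derivative of the loss w.r.t. w: mean(2 * (w*x - y) * x), done by hand here
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # 4. Gradient descent: step against the gradient
    w -= learning_rate * grad

print(round(w, 4))  # prints 2.0
```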
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
# Download (if needed) and load the FashionMNIST training and test splits
training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)
# Wrap each dataset in a DataLoader that yields mini-batches of 64 samples
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# Fully connected network: flatten each 28x28 image, then two hidden layers of 512 units
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
Output:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
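The optimizer below updates every trainable parameter of this network. To give a sense of scale, the following worked arithmetic counts those parameters. It assumes each nn.Linear layer holds an (out x in) weight matrix plus an out-sized bias, and that Flatten and ReLU have no parameters.

```python
# Layer sizes of NeuralNetwork's linear_relu_stack: 28*28 -> 512 -> 512 -> 10
sizes = [28 * 28, 512, 512, 10]

def linear_params(n_in, n_out):
    # nn.Linear stores an (n_out x n_in) weight matrix plus an n_out bias vector
    return n_in * n_out + n_out

counts = [linear_params(a, b) for a, b in zip(sizes, sizes[1:])]
print(counts, sum(counts))  # [401920, 262656, 5130] 669706
```

With the real model, `sum(p.numel() for p in model.parameters())` should report the same total.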
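Hyperparameters
Hyperparameters are adjustable parameters that let you control the model optimization process. Different hyperparameter values can affect model training and convergence rates (read more about hyperparameter tuning).
We define the following hyperparameters for training:
- Number of Epochs - the number of times to iterate over the dataset
- Batch Size - the number of data samples propagated through the network before the parameters are updated
- Learning Rate - how much to update the model's parameters at each batch/epoch. Smaller values yield slow learning, while large values may result in unpredictable behavior during training.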
learning_rate = 1e-3
batch_size = 64
epochs = 5
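The effect of the learning rate can be seen on a toy problem. The sketch below assumes plain gradient descent on f(w) = w**2, whose gradient is 2*w and whose minimum is at w = 0. The three learning rates are arbitrary illustrative choices, not values from this tutorial.

```python
def minimize_square(learning_rate, steps=20, w=1.0):
    # Gradient descent on f(w) = w**2; each step multiplies w by (1 - 2 * learning_rate)
    for _ in range(steps):
        w -= learning_rate * 2 * w
    return w

for lr in (0.01, 0.1, 1.1):
    print(lr, minimize_square(lr))
```

With 0.01, w is still far from 0 after 20 steps (slow learning). With 0.1 it gets close to 0. With 1.1 each step overshoots by more than it corrects, so |w| grows and training diverges.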
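Optimization Loop
Once we set our hyperparameters, we can train and optimize our model with an optimization loop. Each iteration of the optimization loop is called an epoch.
Each epoch consists of two main parts:
- The Train Loop - iterate over the training dataset and try to converge to optimal parameters.
- The Validation/Test Loop - iterate over the test dataset to check whether model performance is improving.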
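Within one epoch of the train loop, the parameters are updated once per batch, not once per epoch. The arithmetic below assumes the 60,000-image FashionMNIST training split, batch_size = 64 and epochs = 5 as defined above. It also assumes DataLoader's default of keeping the final partial batch.

```python
import math

num_train = 60000   # size of the FashionMNIST training split, len(training_data)
batch_size = 64
epochs = 5

batches_per_epoch = math.ceil(num_train / batch_size)  # final batch is partial
last_batch = num_train - (batches_per_epoch - 1) * batch_size
total_updates = batches_per_epoch * epochs

print(batches_per_epoch, last_batch, total_updates)  # 938 32 4690
# train_loop prints progress every 100 batches, at sample index batch * batch_size
print([b * batch_size for b in range(0, batches_per_epoch, 100)])
```

The second line prints 0, 6400, ..., 57600. These are the same positions that appear in the training log further down.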
Loss Function
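When presented with some training data, our untrained network is likely not to give the correct answer. A loss function measures the degree of dissimilarity between the obtained result and the target value, and it is this loss that we want to minimize during training. To calculate the loss, we make a prediction using the inputs of a given data sample and compare it against the sample's true label.
Common loss functions include nn.MSELoss (Mean Square Error) for regression tasks and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
We pass our model's output logits to nn.CrossEntropyLoss, which normalizes the logits and computes the prediction error.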
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
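To show what "combines nn.LogSoftmax and nn.NLLLoss" means for a single sample, here is a pure-Python sketch. The logits are made-up values for three classes. nn.CrossEntropyLoss also averages this per-sample loss over the batch by default (reduction='mean').

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll(log_probs, target):
    # Negative log-likelihood of the correct class
    return -log_probs[target]

def cross_entropy(logits, target):
    # Cross-entropy = NLL applied to the log-softmax of the raw logits
    return nll(log_softmax(logits), target)

logits = [2.0, 1.0, 0.1]         # hypothetical raw outputs for 3 classes
print(cross_entropy(logits, 0))  # smaller loss: class 0 has the largest logit
print(cross_entropy(logits, 2))  # larger loss: class 2 has the smallest logit
```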
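Optimizer
Optimization is the process of adjusting model parameters to reduce model error in each training step. The optimization algorithm defines how this process is performed; in this example we use Stochastic Gradient Descent (SGD). All optimization logic is encapsulated in the optimizer object. Here we use the SGD optimizer. PyTorch also provides many other optimizers, such as ADAM and RMSProp, that can work well for different kinds of models and data.
We initialize the optimizer by registering the model parameters that need to be trained and passing in the learning rate hyperparameter.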
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
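With its defaults (no momentum, no weight decay), SGD updates each parameter p as p - lr * p.grad. The sketch below checks this on a small stand-in nn.Linear layer with random data. It assumes PyTorch is installed; the layer shape, seed and lr = 0.1 are arbitrary choices for the illustration.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(3, 1)  # small stand-in for the tutorial's model
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(4, 3)
target = torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(layer(x), target)

optimizer.zero_grad()
loss.backward()

# Predict the plain SGD update by hand before the optimizer applies it
expected = [(p - 0.1 * p.grad).detach().clone() for p in layer.parameters()]

optimizer.step()

for p, e in zip(layer.parameters(), expected):
    assert torch.allclose(p, e)
print("optimizer.step() matches p - lr * p.grad")
```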
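Inside the training loop, optimization happens in three steps:
- Call optimizer.zero_grad() to reset the gradients of the model parameters. By default, gradients add up; to prevent double-counting, we explicitly zero them at each iteration.
- Call loss.backward() to backpropagate the prediction loss. PyTorch computes the gradient of the loss with respect to each parameter.
- Call optimizer.step() to adjust the parameters using the gradients collected in the backward pass.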
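The "gradients add up" behavior can be seen directly with one scalar weight. In this sketch (assuming PyTorch is installed; the weight, sample and learning rate are illustrative values), the loss is (w * x - y)**2. Its gradient at w = 1 is 2 * (2 - 6) * 2 = -16.

```python
import torch

# One learnable weight and one fixed sample; loss = (w * x - y) ** 2
w = torch.tensor(1.0, requires_grad=True)
x, y = torch.tensor(2.0), torch.tensor(6.0)
optimizer = torch.optim.SGD([w], lr=0.1)

def compute_loss():
    return (w * x - y) ** 2

# First backward pass: dL/dw = 2 * (w*x - y) * x = -16 at w = 1
compute_loss().backward()
first = w.grad.item()

# Second backward pass WITHOUT zeroing: the new gradient is added to w.grad
compute_loss().backward()
accumulated = w.grad.item()

# zero_grad() clears the stored gradient, so the next backward pass starts fresh
optimizer.zero_grad()
compute_loss().backward()
fresh = w.grad.item()

print(first, accumulated, fresh)  # -16.0 -32.0 -16.0
```

If zero_grad() were skipped in train_loop, every optimizer.step() would use the sum of all gradients computed so far rather than the current batch's gradient.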
Full Implementation
We define train_loop, which loops over our optimization code, and test_loop, which evaluates the model's performance on the test data.
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
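The accuracy line in test_loop counts a prediction as correct when the index of the largest logit equals the label. The pure-Python sketch below mirrors that on a hypothetical batch of four samples and three classes; the logits and labels are made up for illustration.

```python
def argmax(row):
    # Index of the largest logit, like pred.argmax(1) applied to one row
    return max(range(len(row)), key=row.__getitem__)

# Hypothetical batch: 4 samples x 3 classes
logits = [
    [2.0, 0.5, 0.1],
    [0.2, 1.5, 0.3],
    [0.1, 0.2, 3.0],
    [1.0, 2.0, 0.5],
]
labels = [0, 1, 1, 0]

correct = sum(1.0 for row, y in zip(logits, labels) if argmax(row) == y)
accuracy = correct / len(labels)
print(correct, accuracy)  # 2.0 0.5
```

Note that test_loop divides the summed loss by the number of batches, because each loss_fn call already returns a per-batch mean. It divides the correct count by the number of samples.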
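We initialize the loss function and optimizer and pass them to train_loop and test_loop. Feel free to increase the number of epochs to track the model's improving performance.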
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
Output:
Epoch 1
-------------------------------
loss: 2.290156 [ 0/60000]
loss: 2.275099 [ 6400/60000]
loss: 2.256799 [12800/60000]
loss: 2.252760 [19200/60000]
loss: 2.235528 [25600/60000]
loss: 2.205756 [32000/60000]
loss: 2.204928 [38400/60000]
loss: 2.172354 [44800/60000]
loss: 2.160271 [51200/60000]
loss: 2.127511 [57600/60000]
Test Error:
Accuracy: 49.9%, Avg loss: 2.116347
Epoch 2
-------------------------------
loss: 2.124757 [ 0/60000]
loss: 2.107859 [ 6400/60000]
loss: 2.045332 [12800/60000]
loss: 2.061512 [19200/60000]
loss: 2.002954 [25600/60000]
loss: 1.940844 [32000/60000]
loss: 1.962774 [38400/60000]
loss: 1.874285 [44800/60000]
loss: 1.875532 [51200/60000]
loss: 1.802694 [57600/60000]
Test Error:
Accuracy: 58.7%, Avg loss: 1.794751
Epoch 3
-------------------------------
loss: 1.830118 [ 0/60000]
loss: 1.797928 [ 6400/60000]
loss: 1.670504 [12800/60000]
loss: 1.718298 [19200/60000]
loss: 1.605203 [25600/60000]
loss: 1.560042 [32000/60000]
loss: 1.583883 [38400/60000]
loss: 1.483568 [44800/60000]
loss: 1.515428 [51200/60000]
loss: 1.414553 [57600/60000]
Test Error:
Accuracy: 62.0%, Avg loss: 1.430290
Epoch 4
-------------------------------
loss: 1.499763 [ 0/60000]
loss: 1.472005 [ 6400/60000]
loss: 1.319050 [12800/60000]
loss: 1.399100 [19200/60000]
loss: 1.283040 [25600/60000]
loss: 1.279892 [32000/60000]
loss: 1.300507 [38400/60000]
loss: 1.221794 [44800/60000]
loss: 1.262865 [51200/60000]
loss: 1.173478 [57600/60000]
Test Error:
Accuracy: 63.9%, Avg loss: 1.193923
Epoch 5
-------------------------------
loss: 1.268049 [ 0/60000]
loss: 1.260393 [ 6400/60000]
loss: 1.092561 [12800/60000]
loss: 1.205449 [19200/60000]
loss: 1.083632 [25600/60000]
loss: 1.101792 [32000/60000]
loss: 1.134809 [38400/60000]
loss: 1.062815 [44800/60000]
loss: 1.108174 [51200/60000]
loss: 1.035161 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.049588
Epoch 6
-------------------------------
loss: 1.114492 [ 0/60000]
loss: 1.130664 [ 6400/60000]
loss: 0.944653 [12800/60000]
loss: 1.083935 [19200/60000]
loss: 0.961972 [25600/60000]
loss: 0.981254 [32000/60000]
loss: 1.033072 [38400/60000]
loss: 0.961604 [44800/60000]
loss: 1.007507 [51200/60000]
loss: 0.948494 [57600/60000]
Test Error:
Accuracy: 66.0%, Avg loss: 0.956025
Epoch 7
-------------------------------
loss: 1.006542 [ 0/60000]
loss: 1.046684 [ 6400/60000]
loss: 0.842564 [12800/60000]
loss: 1.002121 [19200/60000]
loss: 0.884486 [25600/60000]
loss: 0.895794 [32000/60000]
loss: 0.965427 [38400/60000]
loss: 0.895181 [44800/60000]
loss: 0.937755 [51200/60000]
loss: 0.889426 [57600/60000]
Test Error:
Accuracy: 67.3%, Avg loss: 0.891673
Epoch 8
-------------------------------
loss: 0.926312 [ 0/60000]
loss: 0.987333 [ 6400/60000]
loss: 0.768049 [12800/60000]
loss: 0.943189 [19200/60000]
loss: 0.831892 [25600/60000]
loss: 0.833098 [32000/60000]
loss: 0.916814 [38400/60000]
loss: 0.850216 [44800/60000]
loss: 0.887719 [51200/60000]
loss: 0.846100 [57600/60000]
Test Error:
Accuracy: 68.5%, Avg loss: 0.844885
Epoch 9
-------------------------------
loss: 0.864126 [ 0/60000]
loss: 0.941802 [ 6400/60000]
loss: 0.711602 [12800/60000]
loss: 0.898299 [19200/60000]
loss: 0.793915 [25600/60000]
loss: 0.786041 [32000/60000]
loss: 0.879356 [38400/60000]
loss: 0.818412 [44800/60000]
loss: 0.850554 [51200/60000]
loss: 0.812724 [57600/60000]
Test Error:
Accuracy: 69.7%, Avg loss: 0.809041
Epoch 10
-------------------------------
loss: 0.814177 [ 0/60000]
loss: 0.904296 [ 6400/60000]
loss: 0.667563 [12800/60000]
loss: 0.862825 [19200/60000]
loss: 0.764706 [25600/60000]
loss: 0.750034 [32000/60000]
loss: 0.848550 [38400/60000]
loss: 0.794559 [44800/60000]
loss: 0.821466 [51200/60000]
loss: 0.785530 [57600/60000]
Test Error:
Accuracy: 70.9%, Avg loss: 0.780144
Done!
