本節我們將會看到如何保存模型狀態、加載和運行模型預測
import torch
import torchvision.models as models
保存和加載模型權重
PyTorch模型在一個稱為 state_dict 的內部狀態字典內保存了學習的引數,可以通過 torch.save實作這一程序,
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
為了加載模型引數,你需要首先創建一個相同模型的物體,然后使用 load_state_dict()加載引數,
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
注意:在推理前,確保呼叫 model.eval() 設定dropout和batch normalization層是評估模式,否則將產生不一致的推斷結果,
使用Shapes保存和加載模型
當加載模型權重時,我們需要首先初始化模型類,因為該類定義了網路結構,我們可能想將模型權重和該類的結構保存在一起,在這種情況下,可以將 model (而不是model.state_dict())傳入保存函式,
torch.save(model, 'model.pth')
加載
model = torch.load('model.pth')
注意:這種方法在序列化模型時使用Python pickle模塊,因此,它依賴于加載模型時可用的實際類的定義,
相關教程
Saving and Loading a General Checkpoint in PyTorch
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/423657.html
標籤:其他
上一篇:人人都是 Serverless 架構師 | “盲盒抽獎”創意營銷活動實踐
下一篇:花生殼內網穿透
