最簡單的情況
模型保存:
torch.save(model.state_dict(), PATH)
模型加載:
model.load_state_dict(torch.load(PATH))
此時保存的是一個字典,key為model中的weight或bias名,如"linear1.weight"或“linear2.bias”
有時我們使用了優化器
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
我們在保存引數時需要同時保存優化器中的引數:
save_state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict()}
torch.save(save_state, PATH)
在加載時,
model=MyModel()
model.load_state_dict(torch.load("PATH")['net'])
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
optimizer.load_state_dict(torch.load("lab3_lstmtest_0614.pth")['optimizer'])
這樣即保存和加載了模型和優化器引數,繼續上一次訓練,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/18202.html
標籤:其他
