我計劃從 pth 檔案加載重量,例如,
model = my_model()
model.load_state_dict(torch.load("../input/checkpoint/checkpoint.pth")
但是,這里有一個錯誤,說:
RuntimeError: Error(s) in loading state_dict for my_model:
Missing key(s) in state_dict: "att.in_proj_weight", "att.in_proj_bias", "att.out_proj.weight", "att.out_proj.bias".
Unexpected key(s) in state_dict: "in_proj_weight", "in_proj_bias", "out_proj.weight", "out_proj.bias".
似乎我的模型的引數名稱與存盤在state_dict. 在這種情況下,我應該如何使它們保持一致?
uj5u.com熱心網友回復:
您可以創建新字典并修改沒有att.前綴的鍵,您可以將新字典加載到您的模型中,如下所示:
state_dict = torch.load('path\to\checkpoint.pth')
from collections import OrderedDict
new_state_dict = OrderedDict()
for key, value in state_dict.items():
key = key[4:] # remove `att.`
new_state_dict[key] = value
# load params
model = my_model()
model.load_state_dict(new_state_dict)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/381664.html
上一篇:如何使用python計算這些權重和偏差的前向傳遞輸出
下一篇:如何在一系列日期上訓練LSTM?
