我想在另外兩個已經訓練和測驗過的神經網路的幫助下訓練一個神經網路。我要訓練的網路的輸入同時輸入到第一個靜態網路。我要訓練的網路的輸出被輸入到第二個靜態網路。損失應在靜態網路的輸出上計算并傳播回訓練網路。
# Initialization
var_model_statemapper = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 8)])
var_model_panda = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 27)])
var_model_panda.load_state_dict(torch.load("panda.pth"))
var_model_ur5 = NeuralNetwork(8, [('linear', 8), ('relu', None), ('dropout', 0.2), ('linear', 24)])
var_model_ur5.load_state_dict(torch.load("ur5.pth"))
var_loss_function = torch.nn.MSELoss()
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)
# Forward Propagation
var_panda_output = var_model_panda(var_statemapper_input)
var_ur5_output = var_model_ur5(var_statemapper_output)
var_train_loss = var_loss_function(var_panda_output, var_ur5_output)
# Backward Propagation
var_optimizer.zero_grad()
var_train_loss.backward()
var_optimizer.step()
您可以看到“var_model_statemapper”是要訓練的網路。網路“var_model_panda”和“var_model_ur5”被初始化,它們的state_dicts是從相應的“.pth”檔案中讀取的,所以這些網路需要是靜態的。我的主要問題是,哪些網路在反向傳播中更新?只是“var_model_statemapper”還是所有網路?如果“var_model_statemapper”沒有更新,我該如何實作呢?PyTorch 是否僅從優化器的初始化中知道要更新哪個網路?
uj5u.com熱心網友回復:
正式化您的管道以更好地了解設定:
x --- | state_mapper | --> y --- | ur5 | --> ur5_out
\ |
\ ↓
\--- | panda | --> panda_out ----------- | loss_fn | --> loss
這是您提供的行發生的情況:
var_optimizer.zero_grad() # 0.
var_train_loss.backward() # 1.
var_optimizer.step() # 2.
呼叫
zero_grad優化器將清除該優化器中包含的所有引數梯度的快取。在您的情況下,您已經var_optimizer注冊了來自var_model_statemapper(您要優化的模型)的引數。當您通過
backward呼叫推斷損失并對其進行反向傳播時,梯度將通過所有三個模型的引數傳播。然后呼叫
step優化器將更新在你呼叫它的優化器中注冊的引數。在您的情況下,這意味著將使用步驟1中計算的梯度單獨var_optimizer.step()更新模型的所有引數。 (即使用on 呼叫)。var_model_statemapperbackwardvar_train_loss
總而言之,您當前的方法只會更新var_model_statemapper. 理想情況下,您可以凍結模型var_model_panda并將var_model_ur5其引數的requires_grad標志設定為False. 這將節省推理和訓練的速度,因為在反向傳播期間不會計算和存盤它們的梯度。
轉載請註明出處,本文鏈接:https://www.uj5u.com/shujuku/422914.html
標籤:
上一篇:為什么我們在將資料拆分為測驗和訓練之前洗掉目標/標簽?
下一篇:vgg16層的輸出沒有意義
