當我運行下面的代碼時:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
params[var] *= 0.1
報告了一個 RuntimeError:
RuntimeError: result type Float can't be cast to the desired output type Long
但是當我更改params[var] *= 0.1為時params[var] = params[var] * 0.1,錯誤消失了。
為什么會發生這種情況?
我以為params[var] *= 0.1和 效果一樣params[var] = params[var] * 0.1。
uj5u.com熱心網友回復:
首先,讓我們知道 中的第一個 long 型別引數,如果模型中有 BatchNormalization 層densenet201,您會發現它表示訓練期間用于計算均值和方差的 mini-batch 的數量。.features.norm0.num_batches_trackedThis parameter is a long-type number and cannot be float type because it behaves like a counter
其次,在 PyTorch 中,有兩種型別的操作:
- 非就地操作:您將計算后的新輸出分配給變數的新副本,例如 x = x 1 或 x = x / 2。分配前 x 的記憶體位置不等于分配后的記憶體位置,因為您有原始變數的副本。
- 就地操作:當計算直接應用于變數的原始副本而不在此處進行任何復制時,例如 x = 1 或 x /= 2。
讓我們轉到您的示例以了解發生了什么:
非Inplcae操作:
model = torchvision.models.densenet201(num_classes=10) params = model.state_dict() name = 'features.norm0.num_batches_tracked' print(id(params[name])) # 140247785908560 params[name] = params[name] 0.1 print(id(params[name])) # 140247785908368 print(params[name].type()) # changed to torch.FloatTensor就地操作:
print(id(params[name])) # 140247785908560 params[name] = 1 print(id(params[name])) # 140247785908560 print(params[name].type()) # still torch.LongTensor params[name] = 0.1 # you want to change the original copy type to float ,you got an error
最后,幾點說明:
- 就地操作可以節省一些記憶體,但在計算導數時可能會出現問題,因為會立即丟失歷史記錄。因此,不鼓勵使用它們。資源
- 當您決定使用就地操作時應該謹慎,因為它們會覆寫原始內容。
- 如果你使用 pandas,這有點類似于
inplace=Truepandas 中的 :)。
這是閱讀有關就地操作源的更多資訊并閱讀此討論源的好資源。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/442578.html
上一篇:計算多元線性回歸的交叉驗證
