torch.matmul似乎沒有nn.Module包裝器來允許按名稱進行標準前向掛鉤注冊。在這種情況下,矩陣乘法發生在forward()函式的中間。我想forward()除了最終結果之外還可以回傳中間結果,例如return x, mm_res. 但是收集這些額外輸出的好方法是什么?
卸載torch.matmul輸出有哪些選項?TIA。
uj5u.com熱心網友回復:
如果您的主要抱怨是torch.matmul沒有 Module 包裝器這一事實,那么只制作一個怎么樣
class Matmul(nn.Module):
def forward(self, *args):
return torch.matmul(*args)
現在您可以在Matmul實體上注冊前向鉤子
class Network(nn.Module):
def __init__(self, ...):
self.matmul = Matmul()
self.matmul.register_module_forward_hook(...)
def forward(self, x):
y = ...
z = self.matmul(x, y)
...
話雖如此,您一定不能忽視檔案中的警告(紅色),它只能用于除錯目的。
轉載請註明出處,本文鏈接:https://www.uj5u.com/gongcheng/385702.html
