論文:Deformable ConvNets v2: More Deformable, Better Results
論文鏈接:https://arxiv.org/abs/1811.11168
在github上有許多實作,但都沒有很方便的加入到自己的網路中,同時github上代碼很長,對于我這種懶人根本不想呼叫,突然發現torchvision.ops.deform_conv2d代碼,又苦于搜索不到具體如何使用,遂寫下記錄一下,方便其他人,同時本人學藝不精,如果有任何問題歡迎批評指正,
首先,我們需要回到論文,看如何定義Deformable_ConvNet(就直接貼圖了)

簡單來說,將feature map當作一個一個網格,其中輸出結果y中,點p這個坐標的值,取決于,其中
為權重,
為論文中引入的modulation scalar factor,而
需要根據三個引數之和作為輸入,
代表的是原坐標,
是相對于坐標
的相對位移,例如,一個
的卷積核,則
,以上都是標準卷積,本文提出Deformable_Conv就在于加入了新的引數
,即需要網路去學習的一個learnable offset,根據論文offset=
和mask=
都需要進行學習,
原文(The modulation scalar lies in the range [0,1], while
is a real number with unconstrained range.)就是
是一個0到1的數,這也很容易理解,它是一個模型回應引數,
隨便取,
原文(the initial values of and
are 0 and 0.5, respectively. )初始化,
概念理清楚,接下來準備實際操作,(為了更直觀,只取一個卷積層構建網路)
想要替換的正常卷積,代碼如下:
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
def forward(self, x):
out = self.relu(self.conv(x))
return out
使用deform_conv進行替換,代碼如下:
class net(nn.Module):
def __init__(self):
super(dcn, self).__init__()
self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) #原卷積
self.conv_offset = nn.Conv2d(3, 18, kernel_size=3, stride=1, padding=1)
init_offset = torch.Tensor(np.zeros([18, 3, 3, 3]))
self.conv_offset.weight = torch.nn.Parameter(init_offset) #初始化為0
self.conv_mask = nn.Conv2d(3, 9, kernel_size=3, stride=1, padding=1)
init_mask = torch.Tensor(np.zeros([9, 1, 3, 3])+np.array([0.5]))
self.conv_mask.weight = torch.nn.Parameter(init_mask) #初始化為0.5
def forward(self, x):
offset = self.conv_offset(x)
mask = torch.sigmoid(self.conv_mask(x)) #保證在0到1之間
out = torchvision.ops.deform_conv2d(input=x, offset=offset,
weight=self.conv.weight,
mask=mask, padding=(1, 1))
return out
需要注意的點有deform_conv2d的stride默認為(1, 1),padding默認為(0, 0),dilation默認為(1, 1),
ok!這樣就可以完美的將normal_conv替換成deform_conv了!不需要再去github上去看別人巨長的代碼了!感謝pytorch!
最后,隨便用mnist資料集跑跑(訓練集和驗證集都加入了隨機旋轉,網路為四層卷積,三層全連接,加入了dropout,引數都一樣,沒有仔細調整)
normal_conv

deform_conv

可以發現不僅收斂的更快,同時精度更高,deform_conv確實發揮了作用,
最后的最后,本人學藝不精,如果有任何問題,歡迎各位大神批評指正,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/432139.html
標籤:AI
上一篇:影像中的注意力機制詳解(SEBlock | ECABlock | CBAM)
下一篇:【數字信號處理】卷積編程實作 ( 卷積計算原理 | 卷積公式計算 | 使用 matlab 計算卷積 | 使用 C 語言實作卷積計算 )
