本文代碼地址https://github.com/njulhy/funny_code/blob/main/cnn_feature_visualization.ipynb
文章目錄
- CNN特征提取結果可視化——hooks簡單應用
- Hooks簡單介紹
- CNN特征提取的簡單可視化
- 創建CNN特征提取器
- 創建保存hook內容的物件
- 為卷積層注冊hook
- 讀取影像并進行特整體提取
- 查看卷積層特征提取效果
- 查看卷積層數
- 可視化第一個卷積層
- 可視化第二、七個卷積層
- 可視化第16個卷積層
- 結語
CNN特征提取結果可視化——hooks簡單應用
在神經網路搭建時可能出現各式各樣的錯誤,使用hook而非print或者簡單的斷點除錯有助于你更清晰的意識到錯誤所在,
hook的使用場景多種多樣,本文將使用hooks來簡單可視化卷積神經網路的特征提取,用到的神經網路框架為Pytorch
Hooks簡單介紹
每個hook都是預先定義好的可呼叫物件,在pytorch框架中,每個nn.Module物件都能夠方便地注冊(定義)一個hook,當一些trigger方法呼叫(如forward()和backward())后,注冊了hook的nn.Module物件會將相關資訊傳遞到hook里面去,
在PyTorch中,可以注冊三種hook:
- forward prehook (在forward之前執行)
- forward hook (在forward之后執行)
- backward hook (在backward之后執行)
具體理解每種hook的使用不是本文討論的范圍,我們將通過一個生動的卷積神經網路可視化例子來介紹hook的使用
CNN特征提取的簡單可視化
我們將要進行的作業包括:
- 創建CNN特征提取器,本文使用PyTorch自帶的resnet34
- 創建一個保存hook內容的物件
- 為每個卷積層創建hook
- 讀取影像并進行特征提取
- 查看卷積層特征提取效果
本文將對下圖進行特征提取并可視化
創建CNN特征提取器
import torch
import torchvision
feature_extractor = torchvision.models.resnet34(pretrained=True)
if torch.cuda.is_available():
feature_extractor.cuda()
創建保存hook內容的物件
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs=[]
save_output = SaveOutput()
為卷積層注冊hook
hook_handles = []
for layer in feature_extractor.modules():
if isinstance(layer, torch.nn.Conv2d):
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
讀取影像并進行特整體提取
cat.jpg地址
from PIL import Image
from torchvision import transforms as T
image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)
out = feature_extractor(X)
查看卷積層特征提取效果
對于resnet來說,其具體結構如下:
卷積層共有1+6+(4*2+1)+(6*2+1)+(3*2+1)=36個,對conv3_x層有4*2+1卷積層的原因是(1)四個basicblock本身有4*2個卷積層(2)其中一個basicblock進行了downsample,又多了一個卷積層
查看卷積層數
此時每個卷積層的結果都通過hook保存到了save_output.outputs里面,我們查看是否為36個結果
可見全部卷積層的輸出都保存了下來
可視化第一個卷積層
對resnet34來說,首個卷積層的卷積核為7*7,將輸入的三通道彩色影像通道增加至64,尺寸從224*224對折為112*112,tensor的shape為1x64x112x112
我們對首個卷積層的提取結果進行可視化:
import matplotlib.pyplot as plt
plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[0].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))
emm這是第一個卷積層的提取結果,可愛的小貓咪開始黑化

可視化第二、七個卷積層
對resnet34來說,第2-7個卷積層tensor的shape為64x1x56x56,我們對其2個卷積層輸出進行可視化:
plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[1].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))
可見第二個卷積層的結果更加模糊一些

第2-7個卷積層tensor的shape為64x1x56x56,我們對第七個卷積層也可視化:
plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[6].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可視化第16個卷積層
第16個卷積層對應的是conv3_x的結果,其shape為1x128x28x28,可視化如下
plt.figure(figsize = (15,30))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[15].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))
可見影像經過多層特征提取,提取到的特征變得更加高層,大部分通道已經變得難以辨認

結語
對神經網路提取結果進行可視化有助于理解其特征提取逐漸高層化的程序,
hook的使用場景還有很多,希望小伙伴們繼續探索,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/275104.html
標籤:AI
