Flower Classification with an SppLayer
Contents
Flower Classification with an SppLayer
Preface
I. Flatten, GlobalAveragePooling2D, and SpatialPyramidPooling
1. Flatten
2. GlobalAveragePooling2D
3. SpatialPyramidPooling
II. Keras Implementation
Summary
References
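Preface
This is my second blog post of the May Day holiday. In it I summarize what can happen between the convolutional layers and the fully connected layers: using TensorFlow and Keras, I compare Flatten, GlobalAveragePooling2D, and SPP (SpatialPyramidPooling), and I build a small example with an SppLayer. I have only just started learning image classification, so everything here is described from that point of view. The three-dimensional blocks mentioned in this post use the HWC layout (height, width, channels). If anything is wrong, please point it out. The resource links are listed below:
Content | Link
Dataset | link
Flower classification example on Kaggle | link
Flatten, GlobalAveragePooling2D, SPP | link
Cloud drive | link (extraction code: qp09)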
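I. Flatten, GlobalAveragePooling2D, and SpatialPyramidPooling
1. Flatten
The idea behind Flatten is simple: it forcibly reshapes the three-dimensional block produced by the convolutions (batch_size excluded) into a one-dimensional vector, as shown in the figure below. Suppose several rounds of convolution and pooling leave us with a 7x7x5 block. After Flatten it becomes a one-dimensional vector of length 7x7x5=245. Attach a fully connected layer with as many units as there are classes, add a softmax function, and the classification network is complete. So what is the drawback? Suppose the block after the convolutions is fairly large. In my experiment, for example, the input is a 224x224x3 image and VGG16 (without its fully connected layers) performs the convolutions, as shown below: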
Model: "model_flatten"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
dense (Dense)                (None, 10)                250890
=================================================================
Total params: 14,965,578
Trainable params: 14,965,578
Non-trainable params: 0
_________________________________________________________________
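Flattening the convolutional output gives a one-dimensional vector of length 7x7x512=25088. Connecting it to a fully connected layer with 10 units for 10-class classification gives this layer 7x7x512x10+10=250890 trainable parameters, where the extra 10 are the biases. What if I wanted 1000 classes? That would be 25088x1000+1000=25,089,000 parameters, which is far too many. A second drawback is that the parameter count depends on the height and width of the image, so the network must be given a fixed input size.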
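The arithmetic above is easy to check in a few lines. The sketch below is only an illustration: the helper names `dense_params` and `vgg16_feature_side` are made up for this example, and the factor of 32 assumes the standard VGG16 convolutional base with its five 2x2 max-pooling layers.

```python
def dense_params(inputs, units):
    """Parameters of a fully connected layer: one weight per input/unit pair plus one bias per unit."""
    return inputs * units + units


def vgg16_feature_side(image_side):
    """Side length of the VGG16 feature map (assumption: five 2x2 poolings, so 32x downsampling)."""
    return image_side // 32


side = vgg16_feature_side(224)           # 7
flatten_len = side * side * 512          # 7 * 7 * 512

print(flatten_len)                       # 25088
print(dense_params(flatten_len, 10))     # 250890
print(dense_params(flatten_len, 1000))   # 25089000

# A larger input changes the flattened length, and with it the size of the Dense layer.
side_448 = vgg16_feature_side(448)                    # 14
print(dense_params(side_448 * side_448 * 512, 10))    # 1003530
```

The last line shows why Flatten ties the network to one input size: weights trained for a vector of length 25088 cannot be reused for one of length 100352.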
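2. GlobalAveragePooling2D
GlobalAveragePooling2D is a special case of average pooling. It needs no pool_size, strides, or similar arguments, and its result is simply the mean of each channel, so its output depends only on the number of channels. Without further ado, here is the figure. If your task happens to have five classes, matching the number of channels (five in the 7x7x5 example), then great: no fully connected layer is needed, and you can apply softmax directly and classify. If the number of channels differs from the number of classes, you can attach a fully connected layer and then classify. Either way the number of parameters drops dramatically. Again using VGG16 for the convolutions, this time followed by global average pooling (GlobalAveragePooling2D):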
Model: "model_global_avg_pooling"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Functional)           (None, None, None, 512)   14714688
_________________________________________________________________
global_average_pooling2d (Gl (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 14,719,818
Trainable params: 14,719,818
Non-trainable params: 0
_________________________________________________________________
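Note that I no longer have to fix the image size! After VGG16's convolutions we get a block with 512 channels, and GlobalAveragePooling2D turns it into a one-dimensional vector of length 512. Connect a fully connected layer, say for 10 classes, and the trainable parameters of that layer number 512x10+10=5130, far fewer than the 250890 produced by Flatten above.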
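To see why the image size no longer matters, here is a dependency-free sketch of what global average pooling computes. It is not the Keras implementation: `global_average_pool` and `constant_feature_map` are hypothetical helpers, and feature maps are plain nested lists in HWC order.

```python
def global_average_pool(feature_map):
    """Average each channel of an HWC nested list over all spatial positions."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for ch, value in enumerate(pixel):
                sums[ch] += value
    return [total / (height * width) for total in sums]


def constant_feature_map(height, width, channels):
    """Toy feature map in which every value of channel ch equals ch."""
    return [[[float(ch) for ch in range(channels)] for _ in range(width)]
            for _ in range(height)]


# Two feature maps of different spatial size, both with 512 channels.
small = global_average_pool(constant_feature_map(7, 7, 512))
large = global_average_pool(constant_feature_map(14, 10, 512))

print(len(small), len(large))   # 512 512
print(512 * 10 + 10)            # 5130 parameters for a 10-unit Dense layer
```

Both inputs collapse to a vector of length 512, so the Dense layer that follows always sees the same number of inputs.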
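3. SpatialPyramidPooling
SpatialPyramidPooling is spatial pyramid pooling. A conventional convolutional network requires a fixed input size, so when an image does not match that size we usually crop or warp it until it fits. The flow is: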
image -> crop/warp -> conv layers -> fc layers -> output
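The drawback is obvious: cropping or warping can make the object we want to recognize stop looking like itself. Spatial pyramid pooling was introduced to address this, as shown in the figure below. Suppose the convolution and pooling layers produce a set of feature maps of size HxWx256. We max-pool them over 4x4, 2x2, and 1x1 grids of bins, and concatenate the resulting 16x256-d, 4x256-d, and 1x256-d outputs into a single vector of length (16+4+1)x256=5376. The fully connected layer therefore no longer depends on the image size; what matters is the number of channels. The flow is: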
image -> conv layers -> spatial pyramid pooling -> fc layers -> output
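The pooling step itself can be sketched without any framework. The code below is an illustrative, dependency-free version and not the GitHub layer used later in this post: `spp_max_pool` and `ramp_feature_map` are hypothetical names, the bin boundaries follow one common floor/ceil convention, and the output ordering (level by level, bin by bin, then channel) is my own choice.

```python
def spp_max_pool(feature_map, pool_list=(4, 2, 1)):
    """Max-pool an HWC nested list over an n x n grid of bins for each n in pool_list."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    output = []
    for bins in pool_list:
        for i in range(bins):
            # Floor for the start and ceil for the end, so bins cover the map and are never empty.
            row_start = (i * height) // bins
            row_end = ((i + 1) * height + bins - 1) // bins
            for j in range(bins):
                col_start = (j * width) // bins
                col_end = ((j + 1) * width + bins - 1) // bins
                for ch in range(channels):
                    output.append(max(
                        feature_map[r][c][ch]
                        for r in range(row_start, row_end)
                        for c in range(col_start, col_end)))
    return output


def ramp_feature_map(height, width, channels):
    """Toy feature map whose values grow with row, column and channel index."""
    return [[[r * 100 + c * 10 + ch for ch in range(channels)]
             for c in range(width)] for r in range(height)]


vec_a = spp_max_pool(ramp_feature_map(13, 9, 3))
vec_b = spp_max_pool(ramp_feature_map(7, 7, 3))
print(len(vec_a), len(vec_b))   # 63 63, i.e. (16 + 4 + 1) * 3 for both sizes
print((16 + 4 + 1) * 256)       # 5376, the vector length for 256 channels
```

With VGG16's 512 channels the same pyramid gives a vector of length 21x512=10752 whatever the input resolution, so the fully connected layer that follows has a fixed number of inputs.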
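II. Keras Implementation
1. Notes
That is enough theory, so let's look at the code. My own ability is limited and I did not write the SppLayer from scratch. I used code from an expert on GitHub (link). It targets a fairly old Keras version, so I made a few small changes to get it working on version 2.4. Flatten and GlobalAveragePooling2D are both implemented in Keras; if you are interested, have a look at the source.
2. Implementation
First come the image-loading and data-visualization functions. I used them in the previous post and collect them here for easier reuse.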
# First define two utility functions, one for data visualization and one for loading the class folders of images, plus the SPPLayer
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

def visual_train_data(train_path, classes):  # data visualization
    """
    :param train_path: path to the training data
    :param classes: label dictionary, e.g. classes = {0: 'Speed limit (20km/h)',
                                                      1: 'Speed limit (30km/h)',
                                                      2: 'Speed limit (50km/h)',
                                                      3: 'Speed limit (60km/h)',
                                                      4: 'Speed limit (70km/h)'}
    :return: (train_num, class_num), the image counts and class names sorted by count
    """
    print(classes)
    folders = os.listdir(train_path)
    train_num = []
    class_num = []
    index = 0
    for folder in folders:
        print(folder)
        # With the online dataset an extra 'flowers' folder shows up when reading; skip it
        if folder == 'flowers':
            continue
        train_files = os.listdir(train_path + '/' + folder)
        train_num.append(len(train_files))
        class_num.append(classes[index])
        index += 1
    zipped_lists = zip(train_num, class_num)
    sorted_pair = sorted(zipped_lists)
    tuples = zip(*sorted_pair)  # unzip the sorted pairs
    train_num, class_num = [list(t) for t in tuples]
    plt.figure(figsize=(21, 10))
    plt.bar(class_num, train_num)
    plt.xticks(class_num, rotation='vertical')
    plt.show()
    return train_num, class_num

def load_train_data(train_data_dir, imgage_shape=(None, None, 3)):
    """
    :param train_data_dir: path to the training set
    :param imgage_shape: height, width and channels of the images
    :return: image_data (np.array) image data, image_labels data labels,
             type_dict mapping from label index to class name
    """
    IMG_HEIGHT = imgage_shape[0]
    IMG_WIDTH = imgage_shape[1]
    img_resize = None
    folders = os.listdir(train_data_dir)
    image_data = []    # stores the images
    image_labels = []  # stores the labels
    type_dict = {}     # maps each index to its class name
    index = -1         # used as dictionary key and as label
    for folder in folders:
        path = train_data_dir + '/' + folder
        if folder == 'flowers':
            continue
        print('loading ' + path)
        index += 1  # numbering starts at 0
        type_dict[index] = folder.split('-')[-1]
        images = os.listdir(path)
        for img in images:
            try:  # guard against errors while reading
                image = cv2.imread(path + '/' + img)
                if image is None:  # cv2.imread returns None instead of raising
                    raise ValueError('cannot read image')
                if IMG_WIDTH is not None and IMG_HEIGHT is not None:
                    img_resize = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
                else:
                    img_resize = image
                image_data.append(img_resize)
                image_labels.append(index)
            except Exception as err:
                print(err)
                print('Error in ' + img)
    # np.array needs images of one common size, so pass a concrete imgage_shape
    image_data = np.array(image_data, np.float32)
    # np.int was removed from recent NumPy releases; use an explicit integer type
    image_labels = np.array(image_labels, np.int64)
    print("loading finished")
    print("image_data shape ", image_data.shape)
    print("image_labels shape", image_labels.shape)
    return image_data, image_labels, type_dict

# Load the data
train_data_dir = '../input/flowers-recognition/flowers'
image_data, image_labels, dict_type = load_train_data(train_data_dir, (224, 224, 3))
print(dict_type)
visual_train_data(train_data_dir, dict_type)  # count the images per class
The resulting statistics look roughly like this figure:
# SppLayer network definition, taken from GitHub; I did not write it myself
import tensorflow as tf
# from keras.engine.topology import Layer
from tensorflow.keras.layers import Layer
import tensorflow.keras.backend as K


class SpatialPyramidPooling(Layer):
    """Spatial pyramid pooling layer for 2D inputs.
    See Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition,
    K. He, X. Zhang, S. Ren, J. Sun
    # Arguments
        pool_list: list of int
            List of pooling regions to use. The length of the list is the number of pooling regions,
            each int in the list is the number of regions in that pool. For example [1,2,4] would be 3
            regions with 1, 2x2 and 4x4 max pools, so 21 outputs per feature map
    # Input shape
        4D tensor with shape:
        `(samples, channels, rows, cols)` if the data format is 'channels_first'
        or 4D tensor with shape:
        `(samples, rows, cols, channels)` if the data format is 'channels_last'.
    # Output shape
        2D tensor with shape:
        `(samples, channels * sum([i * i for i in pool_list]))`
    """

    def __init__(self, pool_list, **kwargs):
        self.dim_ordering = K.image_data_format()
        assert self.dim_ordering in {'channels_first', 'channels_last'}, 'dim_ordering must be in {channels_first, channels_last}'
        self.pool_list = pool_list
        self.num_outputs_per_channel = sum([i * i for i in pool_list])
        super(SpatialPyramidPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.dim_ordering == 'channels_first':
            self.nb_channels = input_shape[1]
        elif self.dim_ordering == 'channels_last':
            self.nb_channels = input_shape[3]

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.nb_channels * self.num_outputs_per_channel)

    def get_config(self):
        config = {'pool_list': self.pool_list}
        base_config = super(SpatialPyramidPooling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, x, mask=None):
        input_shape = K.shape(x)
        if self.dim_ordering == 'channels_first':
            num_rows = input_shape[2]
            num_cols = input_shape[3]
        elif self.dim_ordering == 'channels_last':
            num_rows = input_shape[1]
            num_cols = input_shape[2]
        # Side lengths (as floats) of one pooling region for each pyramid level
        row_length = [K.cast(num_rows, 'float32') / i for i in self.pool_list]
        col_length = [K.cast(num_cols, 'float32') / i for i in self.pool_list]
        outputs = []
        if self.dim_ordering == 'channels_first':
            for pool_num, num_pool_regions in enumerate(self.pool_list):
                for jy in range(num_pool_regions):
                    for ix in range(num_pool_regions):
                        x1 = ix * col_length[pool_num]
                        x2 = ix * col_length[pool_num] + col_length[pool_num]
                        y1 = jy * row_length[pool_num]
                        y2 = jy * row_length[pool_num] + row_length[pool_num]
                        x1 = K.cast(K.round(x1), 'int32')
                        x2 = K.cast(K.round(x2), 'int32')
                        y1 = K.cast(K.round(y1), 'int32')
                        y2 = K.cast(K.round(y2), 'int32')
                        new_shape = [input_shape[0], input_shape[1],
                                     y2 - y1, x2 - x1]
                        x_crop = x[:, :, y1:y2, x1:x2]
                        xm = K.reshape(x_crop, new_shape)
                        pooled_val = K.max(xm, axis=(2, 3))  # max over the spatial axes
                        outputs.append(pooled_val)
        elif self.dim_ordering == 'channels_last':
            for pool_num, num_pool_regions in enumerate(self.pool_list):
                for jy in range(num_pool_regions):
                    for ix in range(num_pool_regions):
                        x1 = ix * col_length[pool_num]
                        x2 = ix * col_length[pool_num] + col_length[pool_num]
                        y1 = jy * row_length[pool_num]
                        y2 = jy * row_length[pool_num] + row_length[pool_num]
                        x1 = K.cast(K.round(x1), 'int32')
                        x2 = K.cast(K.round(x2), 'int32')
                        y1 = K.cast(K.round(y1), 'int32')
                        y2 = K.cast(K.round(y2), 'int32')
                        new_shape = [input_shape[0], y2 - y1,
                                     x2 - x1, input_shape[3]]
                        x_crop = x[:, y1:y2, x1:x2, :]
                        xm = K.reshape(x_crop, new_shape)
                        pooled_val = K.max(xm, axis=(1, 2))  # max over the spatial axes
                        outputs.append(pooled_val)
        if self.dim_ordering == 'channels_first':
            outputs = K.concatenate(outputs)
        elif self.dim_ordering == 'channels_last':
            # outputs = K.concatenate(outputs, axis=1)
            outputs = K.concatenate(outputs)
            # outputs = K.reshape(outputs, (len(self.pool_list), self.num_outputs_per_channel, input_shape[0], input_shape[1]))
            # outputs = K.permute_dimensions(outputs, (3, 1, 0, 2))
            outputs = K.reshape(outputs, (input_shape[0], self.num_outputs_per_channel * self.nb_channels))
        return outputs
The main changes relative to the GitHub version are listed in the table below:
Original code             Modified
K.image_dim_ordering()    K.image_data_format()
th                        channels_first
tf                        channels_last
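To make the idea behind the layer easier to see, here is a minimal NumPy sketch of the same pooling scheme for a single (H, W, C) feature map. It is only an illustration under assumptions: the function name `spp_max_pool` is made up for this example, and the bin edges are rounded in the same spirit as the layer above rather than reproducing it bit for bit. It shows that the output length depends only on the channel count and `pool_list`, not on the spatial size.

```python
import numpy as np


def spp_max_pool(feature_map, pool_list=(1, 2, 4)):
    """Max-pool an (H, W, C) array over n x n grids and concatenate the results.

    Hypothetical helper for illustration; not part of Keras or the GitHub layer.
    """
    height, width, _ = feature_map.shape
    pooled = []
    for n in pool_list:
        # Region boundaries for an n x n grid, rounded to integer pixel indices
        rows = np.round(np.arange(n + 1) * height / n).astype(int)
        cols = np.round(np.arange(n + 1) * width / n).astype(int)
        for j in range(n):
            for i in range(n):
                region = feature_map[rows[j]:rows[j + 1], cols[i]:cols[i + 1], :]
                pooled.append(region.max(axis=(0, 1)))  # one maximum per channel
    return np.concatenate(pooled)


# Different spatial sizes, same number of channels
for shape in [(7, 7, 32), (13, 10, 32), (50, 64, 32)]:
    out = spp_max_pool(np.random.rand(*shape))
    print(shape, "->", out.shape)
```

With 32 channels and `pool_list=(1, 2, 4)` every input yields 32 x (1 + 4 + 16) = 672 values, which is why the Dense layer after the SPP layer no longer depends on the image size.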
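Next, build the model. Since this is only a learning example and the dataset does not come with a test set, I did not aim for any particular accuracy.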
# Define the network structure
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Input, Dense, BatchNormalization
from tensorflow.keras import backend
from tensorflow.keras.models import Model
from tensorflow.keras.utils import get_source_inputs


def Spp_test_model(input_shape=(None, None, 3), input_tensor=None, classes=5):
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    # Build the network; the input size is deliberately left unrestricted
    # Block 1
    x = Conv2D(filters=32, kernel_size=(3, 3), strides=1, padding='valid', name='conv_block_1', activation='relu')(img_input)
    x = MaxPooling2D(pool_size=(2, 2), strides=1, name='max_pooling_1')(x)
    x = BatchNormalization()(x)
    # Block 2
    x = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same', name='conv_block_2', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=1, name='max_pooling_2')(x)
    x = BatchNormalization()(x)
    # Block 3
    x = Conv2D(filters=32, kernel_size=(3, 3), strides=1, padding='same', name='conv_block_3', activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=1, name='max_pooling_3')(x)
    x = BatchNormalization()(x)
    # The most important part, the SppLayer: 1x1, 2x2 and 4x4 pooling grids as described above
    x = SpatialPyramidPooling([1, 2, 4])(x)
    x = Dense(classes, activation='softmax')(x)  # use the classes argument instead of a hard-coded 5
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    model = Model(inputs, x, name='SppNet')
    return model


model = Spp_test_model(classes=5)
model.summary()
# Split into training and validation sets
from tensorflow import keras
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(image_data, image_labels, train_size=0.7, random_state=42,
                                                  shuffle=True)
del image_data
del image_labels
X_train = X_train / 255.0  # scale pixel values to [0, 1]
X_val = X_val / 255.0  # scale pixel values to [0, 1]
y_train = keras.utils.to_categorical(y_train, 5)  # one-hot encode the 5 classes
y_val = keras.utils.to_categorical(y_val, 5)
print("X_train.shape", X_train.shape)
print("X_valid.shape", X_val.shape)
print("y_train.shape", y_train.shape)
print("y_valid.shape", y_val.shape)
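# Training
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import InverseTimeDecay

lr = 0.0001
epochs = 15
# Same time-based decay as the original Adam(lr=lr, decay=lr / (epochs / 0.5)),
# written as a schedule because newer Keras optimizers no longer accept `decay`
lr_schedule = InverseTimeDecay(initial_learning_rate=lr, decay_steps=1, decay_rate=lr / (epochs / 0.5))
opt = Adam(learning_rate=lr_schedule)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc'])
# Augmentation settings (defined here, but not passed to model.fit below)
aug = ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.15,
    horizontal_flip=False,
    vertical_flip=False,
    fill_mode='nearest'
)
history = model.fit(X_train, y_train, batch_size=50,
                    epochs=epochs, validation_data=(X_val, y_val))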
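As a rough check of what the decay setting above does, the sketch below evaluates the learning rate at the end of training. It assumes the time-based rule used by the older Keras optimizers, lr_t = lr / (1 + decay * t), and takes the 61 steps per epoch from the progress bars in this article's training log; `decayed_lr` is a helper name invented for this example.

```python
def decayed_lr(initial_lr, decay, step):
    """Time-based decay (assumed rule): lr_t = initial_lr / (1 + decay * step)."""
    return initial_lr / (1.0 + decay * step)


lr = 0.0001
epochs = 15
decay = lr / (epochs / 0.5)  # same expression as in the training code
steps_per_epoch = 61         # read off the "61/61" progress bars in the training log
total_steps = epochs * steps_per_epoch

final_lr = decayed_lr(lr, decay, total_steps)
print(f"decay per step: {decay:.3e}")
print(f"learning rate after {total_steps} updates: {final_lr:.6e}")
print(f"relative drop: {(1 - final_lr / lr) * 100:.2f}%")
```

Under these assumptions the learning rate shrinks by well under one percent over the 15 epochs, so the run behaves almost as if the learning rate were constant.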
Network structure:
Model: "SppNet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, None, None, 3)]   0
_________________________________________________________________
conv_block_1 (Conv2D)        (None, None, None, 32)    896
_________________________________________________________________
max_pooling_1 (MaxPooling2D) (None, None, None, 32)    0
_________________________________________________________________
batch_normalization (BatchNo (None, None, None, 32)    128
_________________________________________________________________
conv_block_2 (Conv2D)        (None, None, None, 64)    18496
_________________________________________________________________
max_pooling_2 (MaxPooling2D) (None, None, None, 64)    0
_________________________________________________________________
batch_normalization_1 (Batch (None, None, None, 64)    256
_________________________________________________________________
conv_block_3 (Conv2D)        (None, None, None, 32)    18464
_________________________________________________________________
max_pooling_3 (MaxPooling2D) (None, None, None, 32)    0
_________________________________________________________________
batch_normalization_2 (Batch (None, None, None, 32)    128
_________________________________________________________________
spatial_pyramid_pooling (Spa (None, 672)               0
_________________________________________________________________
dense (Dense)                (None, 5)                 3365
=================================================================
Total params: 41,733
Trainable params: 41,477
Non-trainable params: 256
_________________________________________________________________
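The numbers in the summary can be reproduced by hand. The short sketch below recomputes them from the layer settings; the helper names are invented for this example, and the only assumption is the usual parameter layout (kernel plus bias for Conv2D and Dense, and four vectors per channel for BatchNormalization, two of which are non-trainable moving statistics).

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels kernel per filter, plus one bias per filter
    return kernel_h * kernel_w * in_channels * filters + filters


def batchnorm_params(channels):
    # gamma and beta (trainable) + moving mean and moving variance (non-trainable)
    return 4 * channels


def dense_params(inputs, units):
    return inputs * units + units


spp_outputs = 32 * sum(n * n for n in [1, 2, 4])  # 32 channels * 21 pooling regions

counts = {
    'conv_block_1': conv2d_params(3, 3, 3, 32),
    'batch_normalization': batchnorm_params(32),
    'conv_block_2': conv2d_params(3, 3, 32, 64),
    'batch_normalization_1': batchnorm_params(64),
    'conv_block_3': conv2d_params(3, 3, 64, 32),
    'batch_normalization_2': batchnorm_params(32),
    'dense': dense_params(spp_outputs, 5),
}
total = sum(counts.values())
non_trainable = (32 + 64 + 32) * 2  # moving mean and variance of the three BatchNormalization layers

for name, value in counts.items():
    print(f"{name}: {value}")
print("SPP output size:", spp_outputs)
print("total:", total, "trainable:", total - non_trainable, "non-trainable:", non_trainable)
```

Note that the Dense layer sees 672 inputs regardless of the image size, because the SPP layer always emits 21 values per channel.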
Training log (only part of it is pasted here):
Epoch 11/15
61/61 [==============================] - 75s 1s/step - loss: 0.6994 - acc: 0.7422 - val_loss: 0.9164 - val_acc: 0.6530
Epoch 12/15
61/61 [==============================] - 75s 1s/step - loss: 0.6872 - acc: 0.7438 - val_loss: 0.9210 - val_acc: 0.6515
Epoch 13/15
61/61 [==============================] - 75s 1s/step - loss: 0.6831 - acc: 0.7504 - val_loss: 0.8881 - val_acc: 0.6623
Epoch 14/15
61/61 [==============================] - 75s 1s/step - loss: 0.6232 - acc: 0.7710 - val_loss: 0.8825 - val_acc: 0.6677
Epoch 15/15
61/61 [==============================] - 75s 1s/step - loss: 0.6121 - acc: 0.7739 - val_loss: 0.8817 - val_acc: 0.6669
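That wraps up all the work~
Summary
In practice we still specify the image size when training; the difference is that with Spp and GlobalAveragePooling2D the fully connected layer no longer depends on the image size. Flatten produces far more parameters, but if the parameter count is still within an acceptable range, feel free to use it. My original plan was to read the images in without fixing their size and feed them straight into the network. During loading, however, the images have to be collected in a list, and the list then has to be converted into a numpy array for computation. Because the images differ in size, the elements of the list have different shapes and the conversion fails. If anyone knows a way around this, I would welcome your pointers. Finally, happy May Day holiday to everyone~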
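One possible way around the list-to-array problem, offered only as an untested sketch and not something tried in this article: group the images by shape so that every batch can still be stacked into one array. The generator name `batches_by_shape` is made up for this example, and how each batch is fed to the model (for instance with Keras `model.train_on_batch`) is left open.

```python
from collections import defaultdict

import numpy as np


def batches_by_shape(images, labels, batch_size=50):
    """Yield (x, y) batches in which all images share the same shape.

    Hypothetical helper: images is a list of (H, W, C) arrays of varying size.
    """
    buckets = defaultdict(list)
    for idx, img in enumerate(images):
        buckets[img.shape].append(idx)  # indices of all images with this exact shape
    for indices in buckets.values():
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            x = np.stack([images[i] for i in chunk]).astype(np.float32)
            y = np.asarray([labels[i] for i in chunk])
            yield x, y


# Tiny demonstration with made-up image sizes
demo_images = [np.zeros((4, 4, 3)), np.zeros((5, 6, 3)), np.zeros((4, 4, 3))]
demo_labels = [0, 1, 2]
for x, y in batches_by_shape(demo_images, demo_labels, batch_size=2):
    print(x.shape, y.tolist())
```

The trade-off is that batches can become very small when few images share a shape; resizing to a handful of fixed sizes before bucketing would be one way to keep them larger.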
References
The images in the Flatten and GlobalAveragePooling2D subsections come from Google; the images in the SpatialPyramidPooling subsection come from the original paper.
[1] https://zhuanlan.zhihu.com/p/79888509
[2] https://www.cnblogs.com/zongfa/p/9076311.html
[3] https://phimos.github.io/2020/07/21/RN-SPPLayer/
[4] https://github.com/yhenon/keras-spp/blob/master/spp/SpatialPyramidPooling.py
[5] https://mermaid-js.github.io/mermaid/#/flowchart?id=interaction
[6] https://blog.csdn.net/u011616825/article/details/112302220
[7] https://zh-v2.d2l.ai/chapter_convolutional-neural-networks/lenet.html
Book: 深度學習:卷積神經網路從入門到精通
Book: 《動手學深度學習》 (Dive into Deep Learning)