本文展示了如何使用MATLAB訓練Faster R-CNN目標檢測器,實作對車輛的檢測,本例使用一個包含295張影像的小標記資料集,每個影像包含一個或兩個已標記的車輛目標,一個小的資料集對于探索 Faster R-CNN 訓練程序是有用的,但在實踐中,需要更多的標注影像來訓練一個魯棒的檢測器,
文章目錄
- 一、資料集準備
- 1、資料集下載
- 2、資料集加載
- 3、資料集劃分
- 4、創建資料存盤
- 二、創建Faster R-CNN網路
- 1、初始化引數
- 2、創建網路
- 三、訓練網路
- 1、設定訓練選項
- 2、加載模型訓練
- 四、車輛目標檢測
一、資料集準備
1、資料集下載
本資料集中的許多影像來自加州理工學院1999年和2001年的汽車資料集,可以在加州理工學院計算視覺網站上獲得,由Pietro Perona創建并經許可使用,
% 下載一個預先訓練過的檢測器,以避免等待訓練完成,如果您想訓練檢測器,請將doTraining變數設定為true,
doTraining = false;
if ~doTraining && ~exist('fasterRCNNResNet50EndToEndVehicleExample.mat','file')
disp('Downloading pretrained detector (118 MB)...');
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/fasterRCNNResNet50EndToEndVehicleExample.mat';
websave('fasterRCNNResNet50EndToEndVehicleExample.mat',pretrainedURL);
end
2、資料集加載
下面進行解壓縮車輛影像,加載資料集:
unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;
車輛資料存盤在一個兩串列中,其中第一列包含影像檔案路徑,第二列包含車輛包圍框,
3、資料集劃分
將資料集分解為訓練集、驗證集和測驗集,選擇60%的資料用于訓練,10%用于驗證,其余的用于測驗訓練過的檢測器,
rng(0)
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * height(vehicleDataset));
trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);
4、創建資料存盤
使用imageDatastore和boxLabelDatastore創建用于在培訓和評估期間加載影像和標簽資料的資料存盤,
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,'vehicle'));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,'vehicle'));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,'vehicle'));
組合影像和標簽資料存盤:
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
下面來展示資料集中的一張圖片和對應的標簽框:
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'Rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
原圖和標簽如下:

二、創建Faster R-CNN網路
Faster R-CNN目標檢測網路由特征提取網路和兩個子網路組成,特征提取網路通常是預先訓練的CNN,如ResNet-50或Inception v3,在特征提取網路之后的第一個子網路是區域建議網路(RPN),訓練它生成候選區域——影像中可能存在目標的區域,第二個子網路被訓練來預測每個物件提議的實際類別,
特征提取網路通常是預先訓練的CNN,本例使用ResNet-50進行特征提取,您還可以使用其他預先訓練過的網路,如MobileNet v2或ResNet-18,這取決于您的應用程式需求,
1、初始化引數
使用fasterRCNNLayers創建Faster R-CNN網路,需要你指定幾個引數:
1)確定網路輸入圖片的大小
inputSize = [224 224 3];
2)指定錨框大小和個數
numAnchors = 3;
anchorBoxes = [29 17; 46 39; 136 116];
3)確定特征提取網路
featureExtractionNetwork = resnet50;
在使用 ResNet-50 之前,需要打開附加功能資源管理器,并點擊安裝Deep Learning Toolbox Model for ResNet-50
Network,

2、創建網路
選擇’activation_40_relu’作為特征提取層,該特征提取層輸出特征映射,向下采樣的因子為16,這種向下采樣量是空間解析度和提取特征強度之間的一個很好的權衡,因為進一步向下提取的特征以空間解析度為代價編碼更強的影像特征,選擇最優的特征提取層需要實證分析,您可以使用analyzeNetwork來查找網路中其他潛在特征提取層的名稱,
% 特征提取層
featureLayer = 'activation_40_relu';
% 定義類別
numClasses = width(vehicleDataset)-1;
% 創建網路
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
三、訓練網路
1、設定訓練選項
使用trainingOptions指定網路訓練選項,設定’ValidationData’為預處理的驗證資料,將CheckpointPath設定為臨時位置,這使得在訓練程序中能夠保存部分訓練過的檢測器,如果訓練被中斷,例如斷電或系統故障,您可以從保存的檢查點恢復訓練,
options = trainingOptions('sgdm',...
'MaxEpochs',10,...
'MiniBatchSize',2,...
'InitialLearnRate',1e-3,...
'CheckpointPath',tempdir,...
'ValidationData',validationData);
2、加載模型訓練
如果doTraining為true,則使用trainFasterRCNNObjectDetector訓練Fast R-CNN物件檢測器,否則,加載預訓練的網路,
if doTraining
[detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
'NegativeOverlapRange',[0 0.3], ...
'PositiveOverlapRange',[0.6 1]);
else
% Load pretrained detector for the example.
pretrained = load('fasterRCNNResNet50EndToEndVehicleExample.mat');
detector = pretrained.detector;
end
四、車輛目標檢測
為了檢查訓練效果,我們讀取一張測驗影像,并運行訓練好的檢測器,
% 讀取圖片
I = imread(testDataTbl.imageFilename{3});
% 將影像調整為與訓練影像相同的大小
I = imresize(I,inputSize(1:2));
% 運行檢測器
[bboxes,scores] = detect(detector,I);
% 顯示效果
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)
顯示檢測結果:

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/316355.html
標籤:AI
上一篇:我怎樣才能有2個輸入if陳述句?
