整個YOLOX原始碼的學習一定要按照以下順序才能整體串起來:Backbone->FPN->Head->->資料讀入原始碼->資料增強原始碼->loss計算原始碼->simOTA原始碼->demo.py腳本->train.py腳本,而該系列博文也遵循該順序來逐行分析代碼,注意是逐行,包括python語法,tensor維度和逐行代碼的作用及應用,其實網路結構本沒有任何神秘的地方,都是一些模塊堆疊起來的,你完全可以沒有任何理由的修改任何一個模塊,看完這個系列后自己完全可以隨便的去對任何網路結構做手腳,而不僅僅局限于一個調參者,
只有符合的標簽匹配策略的樣本才會定義為正樣本,只有正樣本所對應的特征圖的像素才能夠參與loss計算及反傳,所以標簽匹配策略是非常重要的,選擇合適的正樣本對精度提升至關重要,
本篇講的是YOLOX中simOTA標簽匹配策略,是YOLOX中loss反向傳播的一部分,也是YOLOX提出的新思想,simOTA是YOLOX作者在OT策略上提出的簡化(simplify)演算法,其作用是為不同目標選擇不同數量的正樣本,在分析代碼之前,首先需要對simOTA策略有清晰的認識,以下是simOTA演算法的步驟流程分解:
simOTA演算法步驟一:首先會通過get_in_boxes_info()方法確定一個正樣本的候選區域,如下圖所示:

- 一個原始影像上面有綠色框和黃色框,灰色的網格代表以FPN的其中一個stride為長度給影像打的網格,一個網格代表feature map中一個像素點所能看到的感受野,綠色框為其中的一個gtbox,黃色框為YOLOX規定的一個正方形區域,這個區域是以當前gtbox的中心點為中心,向上下左右四個方向分別延伸2.5倍的stride(特征圖對應原圖的比例),也就是說不同的特征圖上的特征點的黃色框是不一樣的,如果feature map中的一個像素點對應原圖的中心點在綠色框和黃色框的區域內,那么這個像素點就屬于YOLOX的正樣本的候選區域,
- 注意:get_in_boxes_info()該方法內所有變數都是和gtbox及原始影像相關的真實存在的變數,和網路的預測變數即bboxes_preds_per_image引數沒有任何關系,
simOTA演算法步驟二:計算get_in_boxes_info()得到的正樣本候選區域所產生的每個預測框與當前gtbox的IoU,
simOTA演算法步驟三:將計算所得的IoU按從大到小的順序排序,把排名前n_candidate_k的IoU求和,由于IoU的值不會超過1,因此這個和的值區間為 0 ~ n_candidate_k ,記這個值為dynamic_k,
simOTA演算法步驟四:計算候選區域產生的預測框與當前gtbox的cost值,得到Cost代價矩陣,該矩陣的計算公式為: ,
是平衡系數,
和
分別是一個gtbox和其預測框的分類損失和回歸損失,該矩陣代表當前gtbox和預測框之間的代價關系,預測框的cost值越小越好,通過Cost矩陣,使網路能夠自適應的找到每個gtbox的正樣本,Cost代價矩陣由三個部分組成:
- 每個真實框和當前特征點的預測框的重合程度,重合程度越高,代表這個特征點已經嘗試去擬合該真實框了,因此它的Cost代價就會越小,
- 每個真實框和當前特征點的預測框的分類精度,分類精度越高,也代表這個特征點已經嘗試去擬合該真實框了,因此它的Cost代價就會越小,
- 每個真實框的中心是否落在了特征點的一定半徑內,如果在特征點的一定半徑內,同樣代表這個特征點已經嘗試去擬合該真實框了,因此它的Cost代價就會越小,
simOTA演算法步驟五:將cost矩陣的值按從小到大的順序排列,取前dynamic_k個cost最小的預測框作為當前gtbox最終的正樣本,將其余剩下的預測框作為負樣本,對于不同的gtbox,dynamic_k的值是不一樣的,
simOTA演算法步驟六:使用求出的最終正負樣本來計算分類和回歸損失,
接著我們來到代碼分析環節,接著篇(七)的代碼,定位到yolox\models\yolo_head.py腳本的橘色框(下圖所示):

我們直接進入到該方法,輸入引數的含義如下:
- batch_idx:batchsize的索引,代表每次只取一個batchsize中的一張影像,如果batchsize = 4,則batch_idx = 0,1,2,3.
- num_gt:當前影像有多少個gtbox,
- total_num_anchors:代表FPN生成特征圖的全部像素點的個數,也即YOLOX對一張影像生成預測框的數量,值為20×20+40×40+80×80=8400.
- gt_bboxes_per_image:num_gt個gtbox的邊框值,tensor的維度為(num_gt,4),4代表gtbox的中心點坐標和長寬,
- gt_classes:num_gt個gtbox的類別,
- bboxes_preds_per_image:8400個預測框的坐標資訊,tensor的維度為(8400,4),
- expanded_strides:ferture map上每個像素點的縮放比或預測框之間的步長,tensor的維度為(1,8400),
- x_shifts,y_shifts:每個特征點在原特征圖上的X坐標和Y坐標,對80×80的特征圖而言,數值從0-79,對40×40的特征圖而言,數值從0-39,該tensor的維度為(1,8400),
- cls_preds:每個預測框的預測類別,tensor的維度為(B,8400,num_class),B代表batchsize通道,
- obj_preds:每個預測框的置信度,tensor的維度為(B,8400,1),B代表batchsize通道,
- 其他引數如:bbox_preds、imgs、labels函式內都沒有用到,
該方法首先會執行get_in_boxes_info()方法來獲取正樣本的候選區域,該方法會得到綠色框和黃色框的交集和并集:
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image,expanded_strides,x_shifts,
y_shifts,total_num_anchors,num_gt,)"""fg_mask是綠色框和黃色框的并集、is_in_boxes_and_center是綠色框和黃色框的交集"""
我們跳到該方法,下面是該方法每一行代碼的詳細注釋:
def get_in_boxes_info(self,gt_bboxes_per_image,expanded_strides,x_shifts,y_shifts,total_num_anchors,num_gt,):
expanded_strides_per_image = expanded_strides[0]"""去掉了一個維度,tensor變為(8400,)"""
x_shifts_per_image = x_shifts[0] * expanded_strides_per_image"""得到特征圖上每個特征點對應真實影像上矩形感受野的左上角X坐標"""
y_shifts_per_image = y_shifts[0] * expanded_strides_per_image"""得到特征圖上每個特征點對應真實影像上矩形感受野的左上角Y坐標"""
x_centers_per_image = ((x_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1))"""得到特征圖上每個特征點對應真實影像上矩形感受野的中心點X坐標"""
"""unsqueeze為增加維度方便計算,repeat為將當前資料復制標注框的個數倍(shape變為num_gt × 8400),因為每一個標注框都要與8400個預測框進行比較"""
y_centers_per_image = ((y_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1))"""得到特征圖上每個特征點對應真實影像上矩形感受野的中心點Y坐標"""
"""計算gtbox的四個邊,這樣做是因為輸入YOLOX的是gtbox的左上角坐標和長寬,而gt_bboxes_per_image已經變換成了gtbox的中心點坐標和長寬"""
gt_bboxes_per_image_l = ((gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors))"""得到gtbox左邊x坐標"""
gt_bboxes_per_image_r = ((gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors))"""得到gtbox右邊x坐標"""
gt_bboxes_per_image_t = ((gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors))"""得到gtbox上邊y坐標"""
gt_bboxes_per_image_b = ((gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors))"""得到gtbox下邊y坐標"""
"""判斷8400個中心點有哪些在gtbox內"""
b_l = x_centers_per_image - gt_bboxes_per_image_l"""特征圖上每個特征點對應真實影像上矩形框的中心點X坐標要大于gtbox左邊x坐標"""
b_r = gt_bboxes_per_image_r - x_centers_per_image"""特征圖上每個特征點對應真實影像上矩形框的中心點X坐標要小于gtbox右邊x坐標"""
b_t = y_centers_per_image - gt_bboxes_per_image_t"""特征圖上每個特征點對應真實影像上矩形框的中心點Y坐標要大于gtbox上邊Y坐標"""
b_b = gt_bboxes_per_image_b - y_centers_per_image"""特征圖上每個特征點對應真實影像上矩形框的中心點Y坐標要小于gtbox下邊Y坐標"""
bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)"""將四值進行連接"""
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0"""is_in_boxes的shape是(num_gt,8400),值為True或者False,True代表特征點對應原圖的矩形框中心點在gtbox內"""
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0"""一幅圖中一共有多少個中心點在全部gtbox內,tensor為(8400)"""
center_radius = 2.5"""YOLOX的黃色區域"""
"""求黃色的框的每個邊的坐標"""
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
"""判斷8400個中心點有哪些在黃色區域內"""
c_l = x_centers_per_image - gt_bboxes_per_image_l
c_r = gt_bboxes_per_image_r - x_centers_per_image
c_t = y_centers_per_image - gt_bboxes_per_image_t
c_b = gt_bboxes_per_image_b - y_centers_per_image
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
is_in_centers = center_deltas.min(dim=-1).values > 0.0
is_in_centers_all = is_in_centers.sum(dim=0) > 0
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all"""求黃色框和綠色框的并集"""
is_in_boxes_and_center = (is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor])"""求黃色框和綠色框的交集"""
return is_in_boxes_anchor, is_in_boxes_and_center
接著繼續get_assignments()方法,下面是一直到bboxes_iou()方法的全部注釋,比較簡單:
def get_assignments(self,batch_idx, num_gt, total_num_anchors,gt_bboxes_per_image,gt_classes,bboxes_preds_per_image,
expanded_strides,x_shifts,y_shifts,cls_preds,bbox_preds,obj_preds,labels,imgs,mode="gpu",):
if mode == "cpu":"""如果是CPU訓練就將引數CPU化"""
print("------------CPU Mode for This Batch-------------")
gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
gt_classes = gt_classes.cpu().float()
expanded_strides = expanded_strides.cpu().float()
x_shifts = x_shifts.cpu()
y_shifts = y_shifts.cpu()
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image,expanded_strides,x_shifts,y_shifts,total_num_anchors,num_gt,)"""獲取正樣本的候選區域"""
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]"""shape從之前的(8400,4)變為(在并集區域內中心點的個數,4),這樣可以大大減少計算量"""
cls_preds_ = cls_preds[batch_idx][fg_mask]"""shape從之前的(8400,num_class)變為(在并集區域內中心點的個數,num_class),這樣可以大大減少計算量"""
obj_preds_ = obj_preds[batch_idx][fg_mask]"""shape從之前的(8400,4)變為(在并集區域內中心點的個數,4),這樣可以大大減少計算量"""
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]"""獲取現在的bboxes_preds_per_image的個數"""
if mode == "cpu":"""cpu訓練的話就將引數CPU化"""
gt_bboxes_per_image = gt_bboxes_per_image.cpu()
bboxes_preds_per_image = bboxes_preds_per_image.cpu()
pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)"""計算gtbox和預測框的IOU"""
下面介紹bboxex_iou()方法,我們定位到yolox\utils\boxes.py中的bboxes_iou()方法,下面是一些簡單的注釋,可以忽略:
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
"""bboxes_a為真實框(num_gt,4),bboxes_b為預測框(在候選區域內的像素點的預測框,4),"""
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:"""如過bbox的長度不為4,就出錯啦,"""
raise IndexError
if xyxy:"""YOLOX框的形式不是xyxy,而是中心點的形式"""
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),"""真實框的左上角坐標:中心點坐標-長的一半"""
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),"""預測框的左上角坐標:中心點坐標-長的一半"""
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),"""真實框的右下角坐標:中心點坐標+長的一半"""
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),"""預測框的右下角坐標:中心點坐標+長的一半"""
)
area_a = torch.prod(bboxes_a[:, 2:], 1)"""計算所有真實框的面積"""
area_b = torch.prod(bboxes_b[:, 2:], 1)"""計算所有預測框的面積"""
en = (tl < br).type(tl.type()).prod(dim=2)""""""
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)
接著bboxes_iou()方法往下走,首先給出幾個loss的計算公式:
邊界框loss的計算公式:
![]()
類別loss的計算公式:

cost代價矩陣的計算公式,默認lambda = 3.0:
![]()
給出loss計算的代碼注釋:
"""真實框在每幅影像的方格中one_hot向量,輸出tensor維度為(gtbox的個數,在并集區域內中心點的個數,類別數)"""
gt_cls_per_image = (F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1))
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)"""邊界框損失,iou 越大,匹配度越高,所以需要取負號"""
if mode == "cpu":
cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
cls_preds_ = (cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_())"""類別預測的sigmoid * 置信度預測的sigmoid = 類別分數"""
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)"""二值交叉熵計算類別綜合loss值"""
del cls_preds_
"""構造cost矩陣,"""
cost = (pair_wise_cls_loss+ 3.0 * pair_wise_ious_loss+ 100000.0 * (~is_in_boxes_and_center))"""其中100000.0*(~is_in_boxes_and_center )指正樣本取反,剩下的都是負樣本,一方面需要最小化正樣本的損失,同時意味著需要最大化負樣本的損失,"""
"""cost值越小,表示匹配度越高"""
(num_fg,gt_matched_classes,pred_ious_this_matching,matched_gt_inds,) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
接著進入dynamic_k_matching()方法,其就是YOLOX的標簽匹配策略,其引數含義如下:
- cost:通過回歸損失和類別損失計算得到的cost,
- pair_wise_ious:全部的gtbox和全部預測框的IoU,
- gt_classes:每一個gtbox對應的類別,
- num_gt:gt的數量,
- fg_mask:綠色框和黃色框的交集,
該方法具體注釋如下:
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
matching_matrix = torch.zeros_like(cost)"""生成和cost維度一樣的矩陣"""
ious_in_boxes_matrix = pair_wise_ious
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))"""取10個或者不大于10,一會要把把排名前n_candidate_k的IoU求和"""
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)"""topk為從大到小的排序,并取前n_candidate_k的IoU,維度為(num_gt,10),即每個gtbox都取自己排名前10的IoU"""
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)"""獲取每一個gtbox的正樣本個數,clamp是區間函式,每一個目標保證必須有一個正樣本,因此不能小于1"""
for gt_idx in range(num_gt):"""給每個gtbox都這樣做"""
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)"""選取正樣本"""
matching_matrix[gt_idx][pos_idx] = 1.0"""找到cost最小的位置,然后設定候選框矩陣對應位置為1"""
del topk_ious, dynamic_ks, pos_idx"""為了節約記憶體,釋放這幾個引數"""
anchor_matching_gt = matching_matrix.sum(0)
if (anchor_matching_gt > 1).sum() > 0:"""為了防止一個正樣本對應兩個真實框"""
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)"""比較兩個真實框誰的cost小就作為正樣本,另外一個舍去"""
matching_matrix[:, anchor_matching_gt > 1] *= 0.0"""將大于1的那一列的所有數先全變為0"""
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0"""將cost最小的位置變為1"""
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
num_fg = fg_mask_inboxes.sum().item()"""獲取正樣本的個數"""
fg_mask[fg_mask.clone()] = fg_mask_inboxes"""8400中有哪些是正樣本"""
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)"""每個正樣本對應的真實框的索引"""
gt_matched_classes = gt_classes[matched_gt_inds]"""每個正樣本對應的真實類別"""
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]"""每個正樣本與真實框對應的IoU"""
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
有了dynamic_k_matching()方法回傳的正負樣本,就可以計算后面的loss,回到yolo_head.py中的get_losses()方法,接著進行:
torch.cuda.empty_cache()"""清除顯存,釋放空間"""
num_fg += num_fg_img"""總的正樣本個數"""
cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes) * pred_ious_this_matching.unsqueeze(-1)"""得到num_class個類別對應的IoU"""
obj_target = fg_mask.unsqueeze(-1)"""以正樣本的位置為置信度"""
reg_target = gt_bboxes_per_image[matched_gt_inds]"""框的目標"""
if self.use_l1:
l1_target = self.get_l1_target(
outputs.new_zeros((num_fg_img, 4)),
gt_bboxes_per_image[matched_gt_inds],
expanded_strides[0][fg_mask],
x_shifts=x_shifts[0][fg_mask],
y_shifts=y_shifts[0][fg_mask],
)
cls_targets.append(cls_target)"""把batchsize個圖的正樣本資訊進行拼接"""
reg_targets.append(reg_target)"""把batchsize個圖的正樣本資訊進行拼接"""
obj_targets.append(obj_target.to(dtype))"""把batchsize個圖的正樣本資訊進行拼接"""
fg_masks.append(fg_mask)"""把batchsize個圖的正樣本資訊進行拼接"""
if self.use_l1:
l1_targets.append(l1_target)
cls_targets = torch.cat(cls_targets, 0)
reg_targets = torch.cat(reg_targets, 0)
obj_targets = torch.cat(obj_targets, 0)
fg_masks = torch.cat(fg_masks, 0)
if self.use_l1:
l1_targets = torch.cat(l1_targets, 0)
num_fg = max(num_fg, 1)"""總的正樣本個數"""
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() / num_fg).sum() / num_fg"""計算IoU的loss"""
loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum() / num_fg"""二元交叉熵損失"""
if self.use_l1:
loss_l1 = (self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
else:
loss_l1 = 0.0
reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1"""總損失"""
return (loss,reg_weight * loss_iou,loss_obj,loss_cls,loss_l1,num_fg / max(num_gts, 1),)
代碼部分到這里就算結束了,但YOLOX代碼往深了挖還是需要花很多時間的,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/427475.html
標籤:AI
上一篇:R語言匯入資料檔案(資料匯入、加載、讀取)、使用haven包的read_sav函式匯入SPSS中的sav格式檔案
下一篇:2022年美賽C題思路分享+翻譯
