作者|GUEST
編譯|VK
來源|Analytics Vidhya
介紹
SimCLR論文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解釋了這個框架如何從更大的模型和更大的批處理中獲益,并且如果有足夠的計算能力,可以產生與監督模型類似的結果,
但是這些需求使得框架的計算量相當大,如果我們可以擁有這個框架的簡單性和強大功能,并且有更少的計算需求,這樣每個人都可以訪問它,這不是很好嗎?Moco-v2前來救援,

注意:在之前的一篇博文中,我們在PyTorch中實作了SimCLR框架,它是在一個包含5個類別的簡單資料集上實作的,總共只有1250個訓練影像,
資料集
這次我們將在Pytorch中在更大的資料集上實作Moco-v2,并在Google Colab上訓練我們的模型,這次我們將使用Imagenette和Imagewoof資料集

來自Imagenette資料集的一些影像

這些資料集的快速摘要(更多資訊在這里:https://github.com/fastai/imagenette):
-
Imagenette由Imagenet的10個容易分類的類組成,總共有9479個訓練影像和3935個驗證集影像,
-
Imagewoof是一個由Imagenet提供的10個難分類組成的資料集,因為所有的類都是狗的品種,總共有9035個訓練影像,3939個驗證集影像,
對比學習
對比學習在自我監督學習中的作用是基于這樣一個理念:我們希望同一類別中不同的影像觀具有相似的表征,但是,由于我們不知道哪些影像屬于同一類別,通常所做的是將同一影像的不同外觀的表示拉近,我們把這些不同的外觀稱為正對(positive pairs),

另外,我們希望不同類別的影像有不同的外觀,使它們的表征彼此遠離,不同影像的不同外觀的呈現與類別無關,會被彼此推開,我們把這些不同的外觀稱為負對(negative pairs),

在這種情況下,一個影像的前景是什么?前景可以被認為是以一種經過修改的方式看待影像的某些部分,它本質上是影像的一種變換,
根據手頭的任務,有些轉換可以比其他轉換作業得更好,SimCLR表明,應用隨機裁剪和顏色抖動可以很好地完成各種任務,包括影像分類,這本質上來自于網格搜索,從旋轉、裁剪、剪切、噪聲、模糊、Sobel濾波等選項中選擇一對變換,
從外觀到表示空間的映射是通過神經網路完成的,通常,resnet用于此目的,下面是從影像到表示的管道

負對是如何產生的?
在同一幅影像中,由于隨機裁剪,我們可以得到多個表示,這樣,我們就可以產生正對,
但是如何生成負對呢?負對是來自不同影像的表示,SimCLR論文在同一批中創建了這些,如果一個批包含N個影像,那么對于每個影像,我們將得到2個表示,這總共占2*N個表示,對于一個特定的表示x,有一個表示與x形成正對(與x來自同一個影像的表示),其余所有表示(正好是2*N–2)與x形成負對,
如果我們手頭有大量的負樣本,這些表示就會得到改善,但是,在SimCLR中,只有當批量較大時,才能實作大量的負樣本,這導致了對計算能力的更高要求,MoCo-v2提供了生成負樣本的另一種方法,讓我們詳細了解一下,
動態詞典
我們可以用一種稍微不同的方式來看待對比學習方法,即將查詢與鍵進行匹配,我們現在有兩個編碼器,一個用于查詢,另一個用于鍵,此外,為了得到大量的負樣本,我們需要一個大的鍵編碼字典,

此背景關系中的正對表示查詢與鍵匹配,如果查詢和鍵都來自同一個影像,則它們匹配,編碼的查詢應該與其匹配的鍵相似,而與其他查詢不同,
對于負對,我們維護一個大字典,其中包含以前批處理的編碼鍵,它們作為查詢的負樣本,我們以佇列的形式維護字典,新的batch被入隊,較早的batch被出列,通過更改此佇列的大小,可以更改負采樣數,
這種方法的挑戰
-
隨著鍵編碼器的更改,在稍后時間點排隊的鍵可能與較早排隊的鍵不一致,為了使用對比學習方法,與查詢進行比較的所有鍵必須來自相同或相似的編碼器,這樣比較才會有意義且一致,
-
另一個挑戰是,使用反向傳播學習編碼器引數是不可行的,因為這將需要計算佇列中所有樣本的梯度(這將導致大的計算圖),
為了解決這兩個問題,MoCo將鍵編碼器實作為基于動量的查詢編碼器的移動平均值[1],這意味著它以這種方式更新關鍵編碼器引數:

其中m非常接近于1(例如,典型值為0.999),這確保我們在不同的時間從相似的編碼器獲得編碼鍵,
損失函式-InfoNCE
我們希望查詢接近其所有正樣本,遠離所有負樣本,InfoNC函式E會捕獲它,它代表資訊噪聲對比估計,對于查詢q和鍵k,InfoNCE損失函式是:

我們可以重寫為:

當q和k的相似性增大,q與負樣本的相似性減小時,損失值減小
以下是損失函式的代碼:
τ = 0.05
def loss_function(q, k, queue):
# N是批量大小
N = q.shape[0]
# C是表示的維數
C = q.shape[1]
# bmm代表批處理矩陣乘法
# 如果mat1是b×n×m張量,那么mat2是b×m×p張量,
# 然后輸出一個b×n×p張量,
pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ))
# 在查詢和佇列張量之間執行矩陣乘法
neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1)
# 求和
denominator = neg + pos
return torch.mean(-torch.log(torch.div(pos,denominator)))

讓我們再看看這個損失函式,并將它與分類交叉熵損失函式進行比較,

這里pred?是資料點在第i類中的概率值預測,true?是該點屬于第i類的實際概率值(可以是模糊的,但大多數情況下是一個one-hot),
如果你不熟悉這個話題,你可以看這個視頻來更好地理解交叉熵,另外,請注意,我們經常通過softmax這樣的函式將分數轉換為概率值:https://www.youtube.com/watch?v=ErfnhcEV1O8
我們可以把資訊損失函式看作交叉熵損失,資料樣本“q”的正確類是第r類,底層分類器基于softmax,它試圖在K+1類之間進行分類,
Info-NCE還與編碼表示之間的相互資訊有關;關于這一點的更多細節見[4],
MoCo-v2框架
現在,讓我們把所有的東西放在一起,看看整個Moco-v2演算法是什么樣子的,
步驟1:
我們必須得到查詢和鍵編碼器,最初,鍵編碼器具有與查詢編碼器相同的引數,它們是彼此的復制品,隨著訓練的進行,鍵編碼器將成為查詢編碼器的移動平均值(在這一點上進展緩慢),
由于計算能力的限制,我們使用Resnet-18體系結構來實作,在通常的resnet架構之上,我們添加了一些密集的層,以使表示的維數降到25,這些層中的某些層稍后將充當投影,
# 定義我們的深度學習架構
resnetq = resnet18(pretrained=False)
classifier = nn.Sequential(OrderedDict([
('fc1', nn.Linear(resnetq.fc.in_features, 100)),
('added_relu1', nn.ReLU(inplace=True)),
('fc2', nn.Linear(100, 50)),
('added_relu2', nn.ReLU(inplace=True)),
('fc3', nn.Linear(50, 25))
]))
resnetq.fc = classifier
resnetk = copy.deepcopy(resnetq)
# 將resnet架構遷移到設備
resnetq.to(device)
resnetk.to(device)
步驟2:
現在,我們已經有了編碼器,并且假設我們已經設定了其他重要的資料結構,現在是時候開始訓練回圈并理解管道了,
這一步是從訓練批中獲取編碼查詢和鍵,我們用L2范數對表示進行規范化,
只是一個約定警告,所有后續步驟中的代碼都將位于批處理和epoch回圈中,我們還將張量“k”從它的梯度中分離出來,因為我們不需要計算圖中的鍵編碼器部分,因為動量更新方程會更新鍵編碼器,
# 梯度零化
optimizer.zero_grad()
# 檢索xq和xk這兩個影像batch
xq = sample_batched['image1']
xk = sample_batched['image2']
# 把它們移到設備上
xq = xq.to(device)
xk = xk.to(device)
# 獲取他們的輸出
q = resnetq(xq)
k = resnetk(xk)
k = k.detach()
# 將輸出規范化,使它們成為單位向量
q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1))
k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
步驟3:
現在,我們將查詢、鍵和佇列傳遞給前面定義的loss函式,并將值存盤在一個串列中,然后,像往常一樣,對損失值呼叫backward函式并運行優化器,
# 獲得損失值
loss = loss_function(q, k, queue)
# 把這個損失值放到epoch損失串列中
epoch_losses_train.append(loss.cpu().data.item())
# 反向傳播
loss.backward()
# 運行優化器
optimizer.step()
步驟4:
我們將最新的batch加入我們的佇列,如果我們的佇列大小大于我們定義的最大佇列大小(K),那么我們就從其中取出最老的batch,可以使用torch.cat進行佇列操作,
# 更新佇列
queue = torch.cat((queue, k), 0)
# 如果佇列大于最大佇列大小(k),則出列
# batch大小是256,可以用變數替換
if queue.shape[0] > K:
queue = queue[256:,:]
步驟5:
現在我們進入訓練回圈的最后一步,即更新鍵編碼器,我們使用下面的for回圈來實作這一點,
# 更新resnet
for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()):
θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
一些訓練細節
訓練resnet-18模型的Imagenette和Imagewoof資料集的GPU時間接近18小時,為此,我們使用了googlecolab的GPU(16GB),我們使用的batch大小為256,tau值為0.05,學習率為0.001,最終降低到1e-5,權重衰減為1e-6,我們的佇列大小為8192,鍵編碼器的動量值為0.999,
結果
前3層(將relu視為一層)定義了投影頭,我們將其移除用于影像分類的下游任務,在剩下的網路上,我們訓練了一個線性分類器,
我們得到了64.2%的正確率,而使用10%的標記訓練資料,使用MoCo-v2,相比之下,使用最先進的監督學習方法,其準確率接近95%,
對于Imagewoof,我們對10%的標記資料得到了38.6%的準確率,在這個資料集上進行對比學習的效果低于我們的預期,我們懷疑這是因為首先,資料集非常困難,因為所有類都是狗類,
其次,我們認為顏色是這些類的一個重要的區別特征,應用顏色抖動可能會導致來自不同類的多個影像彼此混合表示,相比之下,監督方法的準確率接近90%,
能夠彌合自監督模型和監督模型之間差距的設計變更:
-
使用更大更寬的模型,
-
通過使用更大的批量和字典大小,
-
使用更多的資料,如果可以的話,同時引入所有未標記的資料,
-
在大量資料上訓練大型模型,然后提取它們,
一些有用的鏈接:
-
谷歌Colab:https://colab.research.google.com/drive/1AepjEbcHPw2Z-xY8iJkvou-Njnn0VZmd?usp=sharing
-
Imagewoof Github倉庫結果:https://github.com/thunderInfy/mocov2-imagewoof-results
-
Imagenette Github倉庫結果:https://github.com/thunderInfy/simclr-with-momentum
-
Imagewoof資料集鏈接:https://github.com/thunderInfy/imagewoof
-
Imagenette資料集鏈接:https://github.com/thunderInfy/imagenette
參考參考
- Momentum Contrast for Unsupervised Visual Representation Learning, Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick(https://arxiv.org/pdf/1911.05722.pdf)
- Improved Baselines with Momentum Contrastive Learning, Xinlei Chen, Haoqi Fan, Ross Girshick, and Kaiming He(https://arxiv.org/pdf/2003.04297.pdf)
- A simple framework for contrastive learning of visual representations, Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton.(https://arxiv.org/pdf/2002.05709.pdf)
- Representation Learning with Contrastive Predictive Coding, Aaron van den Oord, Yazhe Li, and Oriol Vinyals(https://arxiv.org/pdf/1807.03748.pdf)
原文鏈接:https://www.analyticsvidhya.com/blog/2020/08/moco-v2-in-pytorch/
歡迎關注磐創AI博客站:
http://panchuang.net/
sklearn機器學習中文官方檔案:
http://sklearn123.com/
歡迎關注磐創博客資源匯總站:
http://docs.panchuang.net/
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/189516.html
標籤:其他
上一篇:TF2目標檢測API
