作者|Samrat Saha
編譯|VK
來源|Towards Datas Science
Supervised Contrastive Learning這篇論文在有監督學習、交叉熵損失與有監督對比損失之間進行了大量的討論,以更好地實作影像表示和分類任務,讓我們深入了解一下這篇論文的內容,
論文指出可以在image net資料集有1%的改進,

就架構而言,它是一個非常簡單的網路resnet 50,具有128維的頭部,如果你想,你也可以多加幾層,

Code
self.encoder = resnet50()
self.head = nn.Linear(2048, 128)
def forward(self, x):
feat = self.encoder(x)
#需要對128向量進行標準化
feat = F.normalize(self.head(feat), dim=1)
return feat
如圖所示,訓練分兩個階段進行,
-
使用對比損失的訓練集(兩種變化)
-
凍結引數,然后使用softmax損失在線性層上學習分類器,(來自論文的做法)
以上是不言自明的,
本文的主要內容是了解自監督的對比損失和監督的對比損失,

從上面的SCL(監督對比損失)圖中可以看出,貓與任何非貓進行對比,這意味著所有的貓都屬于同一個標簽,都是正數對,任何非貓都是負的,這與三元組資料以及triplet loss的作業原理非常相似,
每一張貓的圖片都會被放大,所以即使是從一張貓的圖片中,我們也會有很多貓,
監督對比損失的損失函式,雖然看起來很可怕,但其實很簡單,

稍后我們將看到一些代碼,但首先是非常簡單的解釋,每個z是標準化的128維向量,
也就是說||z||=1
重申一下線性代數中的事實,如果u和v兩個向量正規化,意味著u.v=cos(u和v之間的夾角)
這意味著如果兩個標準化向量相同,它們之間的點乘=1
#嘗試理解下面的代碼
import numpy as np
v = np.random.randn(128)
v = v/np.linalg.norm(v)
print(np.dot(v,v))
print(np.linalg.norm(v))
損失函式假設每幅影像都有一個增強版本,每批有N幅影像,生成的batch大小= 2*N
在i!=j,yi=yj時,分子exp(zi.zj)/tau表示一批中所有的貓,將i個第128個dim向量zi與所有的j個第128個dim向量點積,

分母是i個貓的影像點乘其他不是貓的影像,取zi和zk的點,使i!=k表示它點乘除它自己以外的所有影像,

最后,我們取對數概率,并將其與批處理中除自身外的所有貓影像相加,然后除以2*N-1

所有影像的總損失和

我們使用一些torch代碼可以理解上面的內容,
假設我們的批量大小是4,讓我們看看如何計算單個批次的損失,
如果批量大小為4,你在網路上的輸入將是8x3x224x224,在這里影像的寬度和高度為224,
8=4x2的原因是我們對每個影像總是有一個對比度,因此需要相應地撰寫一個資料加載程式,
對比損失resnet將輸出8x128維的矩陣,你可以分割這些維度以計算批量損失,
#batch大小
bs = 4
這個部分可以計算分子

temperature = 0.07
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),
temperature)

我們的特征形狀是8x128,讓我們采取3x128矩陣和轉置,下面是可視化后的圖片,

anchor_feature=3x128和contrast_feature=128x3,結果為3x3,如下所示

如果你注意到所有的對角線元素都是點本身,這實際上我們不想要,我們將洗掉他們,
線性代數有個性質:如果u和v是兩個向量,那么當u=v時,u.v是最大的,因此,在每一行中,如果我們取錨點對比度的最大值,并且取相同值,則所有對角線將變為0,
讓我們把維度從128降到2
#bs 1 和 dim 2 意味著 2*1x2
features = torch.randn(2, 2)
temperature = 0.07
contrast_feature = features
anchor_feature = contrast_feature
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),
temperature)
print('anchor_dot_contrast=\n{}'.format(anchor_dot_contrast))
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
print('logits_max = {}'.format(logits_max))
logits = anchor_dot_contrast - logits_max.detach()
print(' logits = {}'.format(logits))
#輸出看看對角線發生了什么
anchor_dot_contrast=
tensor([[128.8697, -12.0467],
[-12.0467, 50.5816]])
logits_max = tensor([[128.8697],
[ 50.5816]])
logits = tensor([[ 0.0000, -140.9164],
[ -62.6283, 0.0000]])
創建人工標簽和創建適當的掩碼進行對比計算,這段代碼有點復雜,所以要仔細檢查輸出,
bs = 4
print('batch size', bs)
temperature = 0.07
labels = torch.randint(4, (1,4))
print('labels', labels)
mask = torch.eq(labels, labels.T).float()
print('mask = \n{}'.format(logits_mask))
#對它進行硬編碼,以使其更容易理解
contrast_count = 2
anchor_count = contrast_count
mask = mask.repeat(anchor_count, contrast_count)
#屏蔽self-contrast的情況
logits_mask = torch.scatter(
torch.ones_like(mask),
1,
torch.arange(bs * anchor_count).view(-1, 1),
0
)
mask = mask * logits_mask
print('mask * logits_mask = \n{}'.format(mask))
讓我們理解輸出,
batch size 4
labels tensor([[3, 0, 2, 3]])
#以上的意思是在這批4個品種的葡萄中,我們有3,0,2,3個標簽,以防你們忘了我們在這里只做了一次對比所以我們會有3_c 0_c 2_c 3_c作為輸入批處理中的對比,
mask =
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
[1., 0., 1., 1., 1., 1., 1., 1.],
[1., 1., 0., 1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1., 1., 1., 1.],
[1., 1., 1., 1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1., 0., 1., 1.],
[1., 1., 1., 1., 1., 1., 0., 1.],
[1., 1., 1., 1., 1., 1., 1., 0.]])
#這是非常重要的,所以我們創建了mask = mask * logits_mask,它告訴我們在第0個影像表示中,它應該與哪個影像進行對比,
# 所以我們的標簽就是標簽張量([[3,0,2,3]])
# 我重新命名它們是為了更好地理解張量([[3_1,0_1,2_1,3_2]])
mask * logits_mask =
tensor([[0., 0., 0., 1., 1., 0., 0., 1.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0.],
[1., 0., 0., 0., 1., 0., 0., 1.],
[1., 0., 0., 1., 0., 0., 0., 1.],
[0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 1., 1., 0., 0., 0.]])

錨點對比代碼
logits = anchor_dot_contrast — logits_max.detach()
損失函式

數學回顧


我們已經有了第一部分的點積除以tau作為logits,
#上述等式的第二部分等于torch.log(exp_logits.sum(1, keepdim=True))
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# 計算對數似然的均值
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
# 損失
loss = - mean_log_prob_pos
loss = loss.view(anchor_count, 4).mean()
print('19. loss {}'.format(loss))
我認為這是監督下的對比損失,我認為現在很容易理解自監督的對比損失,因為它比這更簡單,
根據本文的研究結果,contrast_count越大,模型越清晰,需要修改contrast_count為2以上,希望你能在上述說明的幫助下嘗試,
參考參考
- [1] : Supervised Contrastive Learning
- [2] : Florian Schroff, Dmitry Kalenichenko, and James Philbin. Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 815–823, 2015.
- [3] : A Simple Framework for Contrastive Learning of Visual Representations, Ting Chen, Simon Kornblith Mohammad Norouzi, Geoffrey Hinton
- [4] : https://github.com/google-research/simclr
原文鏈接:https://towardsdatascience.com/a-detailed-study-of-self-supervised-contrastive-loss-and-supervised-contrastive-loss-906f2f27796f
歡迎關注磐創AI博客站:
http://panchuang.net/
sklearn機器學習中文官方檔案:
http://sklearn123.com/
歡迎關注磐創博客資源匯總站:
http://docs.panchuang.net/
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/152087.html
標籤:其他
上一篇:Python中的字典
下一篇:Python中的字典
