CIFAR-10資料集應用:快速入門資料增強方法Mixup,顯著提升影像識別準確度
作者|Ta-Ying Cheng,牛津大學博士研究生,Medium技術博主,多篇文章均被平臺官方刊物Towards Data Science收錄
翻譯|頌賢
深度學習蓬勃發展的這幾年來,影像分類一直是最為火熱的領域之一,傳統上的影像識別嚴重依賴像是擴張/侵蝕或者是頻域變換這樣的處理方法,但特征提取的困難性限制了這些方法的進步空間,
現如今的神經網路則顯著提高了影像識別的準確率,因為神經網路能夠尋找輸入影像和輸出標簽之間的關系,并以此不斷地調整它的識別策略,
然而,神經網路往往需要大量的資料進行訓練,而優質的訓練資料并不是唾手可得的,因此現在許多人都在研究如何能夠實作所謂的資料增強(Data augmentation),即在一個已有的小資料集中憑空增加資料量,來達到以一敵百的效果,
本文就將帶大家認識一種簡單而有效的資料增強策略Mixup,并介紹直接在PyTorch中實作Mixup的方法,
為什么需要資料增強?
神經網路架構內的引數是根據給定的資料進行訓練和更新的,但由于訓練資料只覆寫了某一部分可能資料的分布情況,網路很可能就會在分布的“能見”部分過度擬合,
因此,我們擁有的訓練資料越多,理論上就越能覆寫整個分布的情況,這也正是為什么以資料為中心的AI(data-centric AI)非常重要,當然,在資料量有限的情況下,我們也并不是沒有辦法,通過資料增強,我們就可以嘗試通過微調原有資料的方式產生新資料,并將其作為“新”樣本送入網路進行訓練,
什么是Mixup?
圖1:Mixup的簡易演示圖
假設我們現在要做的事情是給貓和狗的圖片做分類,并且我們已經有了一組標注好了是貓是狗的資料(例如[1, 0] -> 狗, [0, 1] -> 貓),那么Mixup簡單來說就是將兩張影像及其標簽平均化為一個新資料,
具體而言,我們可以用數學公式寫出Mixup的概念:
x
=
λ
x
i
+
(
1
?
λ
)
(
x
j
)
,
y
=
λ
y
i
+
(
1
?
λ
)
(
y
j
)
,
x = \lambda x_i + ( 1 - \lambda ) (x_j),\\ y = \lambda y_i + ( 1 - \lambda ) (y_j),
x=λxi?+(1?λ)(xj?),y=λyi?+(1?λ)(yj?),
其中,x和y分別是混合xi(標簽為y?)和x?(標簽為y?)后的影像和標簽,而λ則是從給定的貝塔分布中取得的亂數,
由此,Mixup能夠為我們提供不同資料類別之間的連續資料樣本,并因此直接擴大了給定訓練集的分布,從而使網路在測驗階段更加強大,
Mixup的萬用性
Mixup其實只是一種資料增強方法,它和任何用于分類的網路架構都是正交的,也就是說,我們可以在任何要進行分類任務的網路中對相應的資料集使用Mixup方法,
Mixup的提出者張宏毅等人基于其最初發表的論文《Mixup: Beyond Empirical Risk Minimization》對多個資料集和架構進行了實驗,發現了Mixup在神經網路之外的應用中也能體現其強大能力,
計算環境
庫
我們將通過PyTorch(包括torchvision)來構建整個程式,Mixup需要的從beta分布中生成的樣本,我們可以從NumPy庫中獲得,我們還將使用random來為Mixup尋找隨機影像,下面的代碼能夠匯入我們需要的所有庫:
"""
Import necessary libraries to train a network using mixup
The code is mainly developed using the PyTorch library
"""
import numpy as np
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
資料集
為了演示,我們將用傳統的影像分類任務來說明Mixup的強大,那么這種情況下CIFAR-10則會是非常理想的資料集,CIFAR-10包含10個類別的60000張彩色影像(每類6000張),按5:1的比例分為訓練和測驗集,這些影像分類起來相當簡單,但比最基本的數字識別資料集MNIST要難一些,
有許多方法可以下載CIFAR-10資料集,比如多倫多大學網站里就包含了相關資料集,在這里,我推薦大家使用格物鈦的公開資料集平臺,因為在這個平臺上,如果使用他們的SDK,不用下載也可以獲取免費的資料集資源,
事實上,這個公開資料集平臺包含了行業內數百個知名的優質資料集,每個資料集都有相關的作者說明,以及不同訓練任務的標簽,例如分類或目標檢測,當然,大家也可以在這個平臺下載其他分類資料集,如CompCars或SVHN,來測驗Mixup在不同場景下的性能,

硬體要求
一般來說,我們最好用GPU(顯卡)來訓練神經網路,因為它能顯著提高訓練速度,不過如果只有CPU可用,我們還是可以對程式進行簡單測驗的,如果你想讓程式能夠自行確定所需硬體,使用以下代碼即可:
"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實作
網路
這里,我們的目標是要測驗Mixup的性能,而不是除錯網路本身,所以我們只需要簡單實作一個4層卷積層和2層全連接層的卷積神經網路(CNN)即可,為了比較使用和不使用Mixup的區別,我們將應用同一個網路來確保比較的準確性,
我們可以使用下列代碼來搭建上面所說的簡單網路:
"""
Create a simple CNN
"""
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# Network consists of 4 convolutional layers followed by 2 fully-connected layers
self.conv11 = nn.Conv2d(3, 64, 3)
self.conv12 = nn.Conv2d(64, 64, 3)
self.conv21 = nn.Conv2d(64, 128, 3)
self.conv22 = nn.Conv2d(128, 128, 3)
self.fc1 = nn.Linear(128 * 5 * 5, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.conv11(x))
x = F.relu(self.conv12(x))
x = F.max_pool2d(x, (2,2))
x = F.relu(self.conv21(x))
x = F.relu(self.conv22(x))
x = F.max_pool2d(x, (2,2))
# Size is calculated based on kernel size 3 and padding 0
x = x.view(-1, 128 * 5 * 5)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return nn.Sigmoid()(x)
Mixup
Mixup階段是在資料集加載程序中完成的,所以我們必須寫入我們自己的資料集,而不是使用torchvision.datasets所提供的默認資料集,
下面的代碼簡單地實作了Mixup,并結合使用了NumPy的貝塔函式,
"""
Dataset and Dataloader creation
All data are downloaded found via Graviti Open Dataset which links to CIFAR-10 official page
The dataset implementation is where mixup take place
"""
class CIFAR_Dataset(Dataset):
def __init__(self, data_dir, train, transform):
self.data_dir = data_dir
self.train = train
self.transform = transform
self.data = []
self.targets = []
# Loading all the data depending on whether the dataset is training or testing
if self.train:
for i in range(5):
with open(data_dir + 'data_batch_' + str(i+1), 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
else:
with open(data_dir + 'test_batch', 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
self.targets.extend(entry['labels'])
# Reshape it and turn it into the HWC format which PyTorch takes in the images
# Original CIFAR format can be seen via its official page
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# Create a one hot label
label = torch.zeros(10)
label[self.targets[idx]] = 1.
# Transform the image by converting to tensor and normalizing it
if self.transform:
image = transform(self.data[idx])
# If data is for training, perform mixup, only perform mixup roughly on 1 for every 5 images
if self.train and idx > 0 and idx%5 == 0:
# Choose another image/label randomly
mixup_idx = random.randint(0, len(self.data)-1)
mixup_label = torch.zeros(10)
label[self.targets[mixup_idx]] = 1.
if self.transform:
mixup_image = transform(self.data[mixup_idx])
# Select a random number from the given beta distribution
# Mixup the images accordingly
alpha = 0.2
lam = np.random.beta(alpha, alpha)
image = lam * image + (1 - lam) * mixup_image
label = lam * label + (1 - lam) * mixup_label
return image, label
需要注意的是,我們并沒有對所有的影像都進行Mixup,而是大概每5張處理1張,我們還使用了一個0.2的貝塔分布,你可以自己為不同的實驗改變分布以及被混合的影像的數量,或許你會取得更好的結果!
訓練和評估
下面的代碼展示的是訓練程序,我們將批次大小設定為128,學習率為1e-3,總次數為30次,整個訓練進行了兩次,唯一區別是有沒有使用Mixup,需要注意的是, 損失函式需要由我們自己定義,因為目前BCE損失不允許使用帶有小數的標簽,
"""
Initialize the network, loss Adam optimizer
Torch BCE Loss does not support mixup labels (not 1 or 0), so we implement our own
"""
net = CNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
def bceloss(x, y):
eps = 1e-6
return -torch.mean(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps))
best_Acc = 0
"""
Training Procedure
"""
for epoch in range(NUM_EPOCHS):
net.train()
# We train and visualize the loss every 100 iterations
for idx, (imgs, labels) in enumerate(train_dataloader):
imgs = imgs.to(device)
labels = labels.to(device)
preds = net(imgs)
loss = bceloss(preds, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if idx%100 == 0:
print("Epoch {} Iteration {}, Current Loss: {}".format(epoch, idx, loss))
# We evaluate the network after every epoch based on test set accuracy
net.eval()
with torch.no_grad():
total = 0
numCorrect = 0
for (imgs, labels) in test_dataloader:
imgs = imgs.to(device)
labels = labels.to(device)
preds = net(imgs)
numCorrect += (torch.argmax(preds, dim=1) == torch.argmax(labels, dim=1)).float().sum()
total += len(imgs)
acc = numCorrect/total
print("Current image classification accuracy at epoch {}: {}".format(epoch, acc))
if acc > best_Acc:
best_Acc = acc
為了評估Mixup的效果,我們進行了三次對照試驗來計算最終的準確性,在沒有Mixup的情況下,該網路在測驗集上的準確率約為74.5%,而在使用了Mixup的情況下,準確率提高到了約76.5%!
影像分類之外
Mixup將影像分類的準確性帶到了一個前所未有的高度,但研究表明,Mixup的好處還能延伸到其他計算機視覺任務中,比如對抗性資料的生成和防御,另外也有相關文獻在Mixup拓展到三維表示中,目前的結果表明Mixup在這一領域也十分有效的,例如PointMixup,
結語
由此,我們用Mixup做的小實驗就大功告成啦!在這篇文章中,我們簡單介紹了Mixup的概念并演示了如何在影像分類網路訓練中應用Mixup,完整的實作方式可以在這—GitHub倉庫中找到,
【關于格物鈦】:
格物鈦智能科技定位為面向機器學習的資料平臺,致力于為 AI 開發者打造下一代新型基礎設施,從根本上改變其與非結構化資料的互動方式,我們通過非結構化資料管理工具TensorBay和開源資料集社區Open Datasets,幫助機器學習團隊和個人降低資料獲取、存盤和處理成本,加速 AI開發和產品創新,為人工智能賦能千行百業、驅動產業升級提供堅實基礎,
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/295032.html
標籤:其他
上一篇:貪吃蛇(C語言實作)
