LeNet模型對CIFAR-10資料集分類【pytorch】
- 目錄
- LeNet 網路模型
- CIFAR-10 資料集
- Pytorch 實作代碼
目錄
本文針對CIFAR-10資料集,以簡單的神經網路LeNet實作影像分類(使用PyTorch實作)。
LeNet 網路模型

提示:以上為原始LeNet的結構。為了取得更好的效果,我在實作模型時將此處的池化層換成了最大池化(Max Pooling)。
CIFAR-10 資料集
CIFAR-10資料集由60000張32x32的彩色影像組成,分為10類,每類有6000張影像,其中有50000張訓練影像和10000張測試影像。
該資料集被分為五個訓練批次和一個測試批次,每個批次有10000張影像。測試批次包含從每個類別中隨機選取的1000張影像;訓練批次則以隨機順序包含其餘的影像,但某些訓練批次中某一類別的影像可能比另一類別多。整體而言,五個訓練批次合計恰好包含每個類別各5000張影像。
下面是資料集中的類別,以及每個類別的10張隨機影像。

關于資料集更多詳情請見:CIFAR-10資料集官方說明
Pytorch 實作代碼
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
class Lenet5(nn.Module):
    """LeNet-5 style CNN (max-pooling variant) producing class logits.

    The fully connected head is sized for 32x32 spatial input (e.g. CIFAR-10):
    32x32 -conv1(pad 2)-> 32x32 -pool-> 16x16 -conv2-> 12x12 -pool-> 6x6.

    Args:
        input_channels: number of channels in the input images (3 for RGB).
        num_classes: number of output logits; defaults to 10 (CIFAR-10).
    """

    def __init__(self, input_channels, num_classes=10):
        super().__init__()
        # Conv 1: kernel 5 with padding 2 keeps the spatial size (32x32 -> 32x32).
        self.conv1 = nn.Conv2d(input_channels, 6, kernel_size=5, padding=2)
        # Pool 1: halves the spatial size (32x32 -> 16x16).
        self.pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Conv 2: kernel 5 without padding (16x16 -> 12x12).
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # Pool 2: halves the spatial size (12x12 -> 6x6).
        self.pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: flatten, then three fully connected layers.
        self.Flatten = nn.Flatten()
        # 16 feature maps of 6x6 each, per the shape trace above.
        self.Linear1 = nn.Linear(16 * 6 * 6, 120)
        self.Linear2 = nn.Linear(120, 84)
        self.Linear3 = nn.Linear(84, num_classes)

    def forward(self, X):
        """Return raw (unnormalized) logits of shape (batch, num_classes)."""
        X = self.pooling1(F.relu(self.conv1(X)))
        X = self.pooling2(F.relu(self.conv2(X)))
        X = self.Flatten(X)
        X = F.relu(self.Linear1(X))
        X = F.relu(self.Linear2(X))
        # No activation on the output layer: nn.CrossEntropyLoss expects raw
        # logits, and a ReLU here would clamp every negative score to 0.
        return self.Linear3(X)
def load_CIFAR10(batch_size, resize=None):
    """Download CIFAR-10 into ./dataset and wrap both splits in DataLoaders.

    Every image is converted with ToTensor and normalized per channel with
    mean 0.5 and std 0.5. When ``resize`` is truthy, a Resize step is applied
    before those two.

    Returns:
        (train_loader, test_loader): the training loader shuffles, the test
        loader does not; both use ``num_workers=2``.
    """
    steps = []
    if resize:
        steps.append(transforms.Resize(resize))
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    pipeline = transforms.Compose(steps)

    def _split(is_train):
        # One CIFAR-10 split, downloaded on first use.
        return torchvision.datasets.CIFAR10(root="dataset",
                                            train=is_train,
                                            transform=pipeline,
                                            download=True)

    train_set = _split(True)
    test_set = _split(False)
    train_loader = data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=2)
    test_loader = data.DataLoader(test_set, batch_size, shuffle=False,
                                  num_workers=2)
    return train_loader, test_loader
def get_labels(labels):
    """Translate CIFAR-10 class indices into short text names.

    Each element of ``labels`` is passed through ``int()`` first, so plain ints
    and elements of an integer tensor are both accepted.
    """
    names = ('plane', 'car', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck')
    return list(map(lambda idx: names[int(idx)], labels))
def train(loss, updater, train_iter, net, epoches):
    """Train ``net`` for ``epoches`` passes over ``train_iter``.

    Args:
        loss: criterion called as ``loss(y_hat, y)``.
        updater: optimizer over ``net``'s parameters.
        train_iter: iterable of ``(X, y)`` batches.
        net: model to train; each batch is moved to the device of its
            parameters (if it has any).
        epoches: number of passes over ``train_iter``.

    Returns:
        A list with the mean per-batch loss of each epoch (0.0 for an epoch
        that saw no batches). Callers that ignored the old ``None`` return
        are unaffected.
    """
    net.train()
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None
    epoch_losses = []
    for epoch in range(epoches):
        run_loss = 0.0
        n_batches = 0
        y_hat = None
        for X, y in train_iter:
            if device is not None:
                X, y = X.to(device), y.to(device)
            y_hat = net(X)               # call the module (not .forward) so hooks run
            ls = loss(y_hat, y).sum()    # .sum() turns a per-sample loss into a scalar
            updater.zero_grad()          # clear gradients from the previous step
            ls.backward()                # compute new gradients
            updater.step()               # update the weights
            run_loss += ls.item()
            n_batches += 1
        # Average over the batches actually seen; the former hard-coded 5000
        # was only right for batch_size=10 on the 50000-image training set.
        avg_loss = run_loss / n_batches if n_batches else 0.0
        epoch_losses.append(avg_loss)
        if y_hat is not None:
            # Show the last batch of the epoch next to the epoch's mean loss.
            print(f'true:{y} preds:{y_hat.argmax(axis=1)} epoch:{epoch:02d}\t epoch_loss {avg_loss}\t ')
    print('finished training\n')
    return epoch_losses
def predict(net, test_iter, n=6):
    """Print ground truth vs. predicted class names for the test set.

    For every batch in ``test_iter``, prints a list with one
    ``'groundTruth :<true> preds: <pred>'`` string for each of the first ``n``
    samples.

    The net is put into eval mode (and left there), and inference runs under
    ``torch.no_grad()``, so no autograd graph is built. Batches are moved to
    the device of the net's parameters.
    """
    net.eval()
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None
    with torch.no_grad():
        for X, y in test_iter:
            if device is not None:
                X, y = X.to(device), y.to(device)
            # Only the first n samples are printed, so only convert those.
            trues = get_labels(y[:n])
            preds = get_labels(net(X).argmax(axis=1)[:n])
            titles = ['groundTruth :' + true + ' ' + 'preds: ' + pred
                      for true, pred in zip(trues, preds)]
            print(titles)
if __name__ == '__main__':
    # Hyperparameters
    batch_size, learning_rate, epoches = 10, 0.05, 1
    # Load the data
    trainSet, testSet = load_CIFAR10(batch_size)
    # Build the model (3 input channels for RGB images)
    net = Lenet5(3)
    if torch.cuda.is_available():
        net.cuda()
    # Loss function
    loss = nn.CrossEntropyLoss()
    # Optimizer; the learning rate is applied here rather than inside train()
    updater = torch.optim.SGD(net.parameters(), lr=learning_rate)
    # Train: train() takes (loss, updater, train_iter, net, epoches). The old
    # call also passed batch_size and learning_rate, which raised TypeError.
    train(loss, updater, trainSet, net, epoches)
    # Predict on the test set
    predict(net, testSet)
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/289473.html
標籤:AI
上一篇:人工智能基礎-數學知識之線性代數
下一篇:案例分享:Qt+Arm基于RV1126平臺的內窺鏡軟硬整套解決方案(實時影像、凍結、拍照、錄像、背光調整、硬體光源調整,其他產品也可使用該平臺,如視頻監控,物聯網產品等等)
