機器學習筆記(3)CNN
參考網址:https://blog.csdn.net/out_of_memory_error/article/details/81434907
https://blog.csdn.net/cxmscb/article/details/71023576
基本概念學習
輸入層
對于黑白圖片,輸入層為nm,對于RGB圖片,輸入層為3nm,(nm為圖片大小)
卷積
卷積是指使用卷積核進行特征提取和特征映射的程序,具有區域連接和共享權重的特點,
幾個關鍵詞:
感受視野(local receptive fields):一個像素對應回原圖的區域大小,這里為卷積核的大小3*3
深度(depth):卷積核的維度
步長(stride):卷積核每次在圖上移動的步長
填充值(zero-padding):是不是要在外圈加0

卷積就是對原圖上和卷積核大小相同的區域做矩陣內積,對于上述圖片,輸入為3通道77的矩陣,卷積核深度為2,大小為33,即卷積核為2(深度)*3(每維個數)33,卷積結果加上偏執(3[每維個數]33)就得到輸出(2[深度]33),
池化
圖片來自知乎,

池化也稱為下采樣,可以增大感受野,池化有平均池化,最大值池化,
經典網路LeNet-5學習

0.輸入為一張13232的黑白圖片,
1.第一層為卷積C1,卷積核為6155(深度,每維個數,長,寬),步長為1,輸出為62828.
2.第二層為池化S2,pool核為1622,步長為2,輸出為61414.
3.第三層為卷積C3,卷積核為16655,步長為1,輸出為161010
4.第四層為池化S4,pool核為11622,步長為2,輸出為1655
5.第五層為卷積C5,卷積核為120655,卷積核大小與輸入影像大小相同,輸出為12011
6.第六層為全連接層F6,權重引數有12084個,閾值引數有84個,輸出為84
7.第七層為輸出層,引數有10*(84+1)個,輸出為10
CNN實戰:手寫數字識別MNIST資料集
環境
python 3.6
pytorch 1.6.0
cuda 10.2
CNN代碼
import torch
import time
import numpy as np
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
###CNN網路結構
class CNN(nn.Module):
def __init__(self):
# 初始化
super(CNN,self).__init__()
#卷積1*28*28->25*26*26
self.layer1 = nn.Sequential(
nn.Conv2d(1,25,kernel_size=3),
nn.BatchNorm2d(25),
nn.ReLU(inplace=True)
)
#卷積25*26*26 -> 50*24*24
self.layer2 = nn.Sequential(
nn.Conv2d(25,50,kernel_size=3),
nn.BatchNorm2d(50),
nn.ReLU(inplace=True)
)
#池化50*24*24 -> 50*12*12
self.layer3 = nn.Sequential(
nn.MaxPool2d(kernel_size= 2 ,stride=2)
)
#卷積50*12*12 -> 100*10*10
self.layer4 = nn.Sequential(
nn.Conv2d(50,100,kernel_size=3),
nn.BatchNorm2d(100),
nn.ReLU(inplace = True)
)
#池化100*10*10 -> 100*5*5
self.layer5 = nn.Sequential(
nn.MaxPool2d(kernel_size=2,stride=2)
)
#全連接層
self.layer6 = nn.Sequential(
nn.Linear(100*5*5,1024),
nn.ReLU(inplace=True),
nn.Linear(1024,128),
nn.ReLU(inplace=True),
nn.Linear(128,10),
)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = x.view(x.size(0),-1) #將張量reshape為一維
x = self.layer6(x)
return (x)
###MNIST資料集下載
#定義資料引數
batch_size = 100
#影像預處理:Totensor歸一化到(0,1),Normalize標準化到(-1,1)
data_tf = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])
])
#使用dataset類下載
dataset_train = datasets.MNIST(root='./MNIST',
train=True,transform=data_tf,download=True)
dataset_test = datasets.MNIST(root='./MNIST',
train=True,transform=data_tf)
#使用Dataloader加載資料(shuffle表示是否打亂資料順序)
train_loader = DataLoader(dataset_train,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset_test,batch_size=batch_size,shuffle=True)
###定義和初始化
learning_rate = 0.01
epoches_num = 3
#使用GPU
CNN_model = CNN()
if torch.cuda.is_available():
CNN_model = CNN_model.cuda()
#損失函式和優化器
lossfunction = nn.CrossEntropyLoss()
optimizer = optim.SGD(CNN_model.parameters(),lr=learning_rate)
###訓練
epoch = 0
loss_log = []
for i in range(epoches_num):
start = time.clock()
for data in train_loader:
img,lable = data
#變數初始化
if torch.cuda.is_available():
img = img.cuda()
lable = lable.cuda()
else:
img = Variable(img)
lable = Variable(lable)
output = CNN_model.forward(img)
#計算LOSS
loss = lossfunction(output,lable)
loss_print = loss.data.item()
#梯度置零
optimizer.zero_grad()
#反向傳播
loss.backward()
optimizer.step()
end = time.clock()
epoch+=1
print('第',epoch,'輪耗時',end-start)
print('epoch:{},loss:{}'.format(epoch, loss_print))
loss_log.append(loss_print)
#繪制損失函式曲線
plt.figure(1)
plt.plot(epoch, loss_log)
plt.show()
###測驗
CNN_model.eval()
eval_loss = 0
eval_acc = 0
for data in test_loader:
img_test,lable_test = data
if torch.cuda.is_available():
img_test = img_test.cuda()
lable_test = lable_test.cuda()
output_test = CNN_model.forward(img_test)
loss_test = lossfunction(output_test,lable_test)
eval_loss += loss_test.data.item()*lable_test.size(0)
_, pred = torch.max(output_test, 1)
num_correct = (pred == lable_test).sum()
eval_acc += num_correct.item()
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
eval_loss / (len(dataset_test)),
eval_acc / (len(dataset_test))
))
訓練集損失函式曲線:

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/153714.html
標籤:其他
