主頁 >  其他 > 卷積神經網路訓練CIFAR100

卷積神經網路訓練CIFAR100

2021-01-18 11:05:40 其他

一、前言

1.1、資料集介紹

? CIFAR-10資料集由10個類的60000個32x32彩色影像組成,每個類有6000個影像,有50000個訓練影像和10000個測驗影像,
? 資料集分為五個訓練批次和一個測驗批次,每個批次有10000個影像,測驗批次包含來自每個類別的恰好1000個隨機選擇的影像,訓練批次以隨機順序包含剩余影像,但一些訓練批次可能包含來自一個類別的影像比另一個更多,總體來說,五個訓練集之和包含來自每個類的正好5000張影像,
? 以下是資料集中的類,以及來自每個類的10個隨機影像:
在這里插入圖片描述

? CIFAR-100 資料集就像CIFAR-10,除了它有100個類,每個類包含600個影像,,每類各有500個訓練影像和100個測驗影像,CIFAR-100 中的100個類被分成20個超類,每個影像都帶有一個“精細”標簽(它所屬的類)和一個“粗糙”標簽(它所屬的超類),以下是 CIFAR-100 中的類別串列:

在這里插入圖片描述

1.2、本次網路介紹

? 本次共搭建18個網路層:10個卷積層、5個最大池化層以及3個全連接層,其中2個卷積層+1個最大池化層為1個單元,具體網路引數如下圖所示:

在這里插入圖片描述

二、匯入庫

? 本次使用的Tensorflow版本為2.0.0,為方便搭建網路,利用 kears 構建網路,因此,匯入 kears 的 layers、Sequential、optimizers、datasets,

import tensorflow as tf
from tensorflow.keras import layers,Sequential,optimizers,datasets

三、匯入Cifar-100資料集并查看資料集格式

? 利用 kearas.datasets 函式直接匯入資料集,并分為訓練集與測驗集,并查看資料集形狀:

(x,y),(x_test,y_test) = datasets.cifar100.load_data() # 加載資料集
print('x:',x.shape,'y:',y.shape,'x_test:',x_test.shape,"y_test:",y_test.shape)

? 得到的資料格式如下:

在這里插入圖片描述

? 因為 y 與 y_test 都是標簽集,因此,其維度均應為1維,而加載后的卻是二維張量,所以要進行資料處理,

y = tf.reshape(y,[50000])
y_test = tf.reshape(y_test,[10000])
print('y:',y.shape,'y_test:',y_test.shape)

在這里插入圖片描述

四、資料處理

4.1、資料預處理函式

? 通過資料與處理函式,將 numpy 型別資料變為 tensorflow 型別資料,

def preprocess(x,y):                      # 預處理函式
    # [0-1]
    x = tf.cast(x,dtype=tf.float32) / 255 # 歸一化處理
    y = tf.cast(y,dtype=tf.int64)         # y必須是整數型,因為onehot里輸入不能為float
    return x,y

? 通過資料預處理函式便可以將別的型別的資料變為 tensorflow 型別,因為像素均為 0 -255 值,所以進行歸一化處理,將資料值分布變為 0 -1,由后面要進行 one_hot 編碼,所以標簽集資料必須為 tfint64型別,

4.2、進行資料處理

train_db = tf.data.Dataset.from_tensor_slices((x,y)) # 將輸入的張量的第一個維度看做樣本的個數,沿其第一個維度將tensor切片,得到的每個切片是一個樣本資料,實作了輸入張量的自動切片,
train_db = train_db.shuffle(10000).map(preprocess).batch(64) # 64個樣本為一個 batch
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(64)

? 首先對于訓練集,對變數進行切片處理,然后利用 shuffle 函式打亂資料,利用與處理函式進行資料轉換,并64個樣本為一個batch;而對于測驗集,并不需要來打亂樣本順序,

五、搭建網路層

? 為便于區分,本次搭建分為兩步:一是卷積層——包括卷積與池化兩個操作,選擇最大池化操作;二是全連接層,最后一層輸出神經元為100(本次分類的類數)且最后一層不添加激活函式,

5.1、卷積層搭建

? 卷積層共 5 個單元,每個單元都是“2+1”組合:2個卷積層,一個池化層,具體形式如下:

conv_layers = [
    # unit 1
    layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 2
    layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 3
    layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 4
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 5
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same')

]

5.2、全連接層搭建

? 全連接層共三層,其引數如下:

fc_net = Sequential([
        layers.Dense(256, activation=tf.nn.relu),     # 輸出為 256 
        layers.Dense(128, activation=tf.nn.relu),	  # 輸出為 128
        layers.Dense(100, activation=None),			  # 輸出為 100,且不加激活函式
    ])

六、設立各部分網路輸入形式、優化器以及定義所有變數

? 因為照片是 32 × 32 × 3 32\times32\times3 32×32×3 的,因此第一部分的輸入是 [ N o n e , 32 , 32 , 3 ] [None,32,32,3] [None,32,32,3] ,其中None 代表樣本數;第二部分全連接層輸入為 [ N o n e , 512 ] [None,512] [None,512] ,因此,為保證輸入格式正確,需要對第一部分網路的輸出進行 r e s h a p e reshape reshape 操作,

? 選擇 A d a m Adam Adam 優化器,學習率設定為 1 e ? 4 1e-4 1e?4 ,網路的總變數 = 第一部分變數 + 第二部分變數,

conv_net.build(input_shape=[None, 32, 32, 3])
fc_net.build(input_shape=[None,512])

optimizer = optimizers.Adam(lr=1e-4)
variables = conv_net.variables + fc_net.variables

七、前向傳播與誤差計算

? 訓練 50 次, b a t c h s i z e = 64 batchsize = 64 batchsize=64 ,每訓練 100 個批次,輸出列印準確率與損失函式值,

for epoch in range(50):
    for step,(x,y) in enumerate(train_db):        # enumerate() 函式用于將一個可遍歷的資料物件(如串列、元組或字串)組合為一個索引序列,同時列出資料和資料下標,一般用在 for 回圈當中
        with tf.GradientTape() as tape:
            # [b,32,32,3] => [b,1,1,512]
            out = conv_net(x)
            # flaten
             out = tf.reshape(out,[-1,512])
            # [b,512] => [b,100]
             logits = fc_net(out)
            # onehot編碼:[b] => [b,100]
            y_onehotcode = tf.one_hot(y,100) # y不能為float
            # compute loss
            loss = tf.losses.categorical_crossentropy(y_onehotcode,logits,from_logits=True)
            loss = tf.reduce_mean(loss)

八、梯度求解與權重更新

grads = tape.gradient(loss,variables)           # 求解梯度
optimizer.apply_gradients(zip(grads,variables)) # 梯度更新
if step % 100 ==0:
    print(epoch,step,'loss:',float(loss))

九、測驗訓練集

for x,y in test_deb:
    out = conv_net(x)
    out = tf.reshape(out,[-1,512]) # 保證全連接層輸入正確
	logits = fc_net(out)
    prob = tf.nn.softmax(logits,axis=1) 
    pred = tf.argmax(prob,axis= 1)	# 把全連接層的輸出進行分類
    # 求解準確率
    correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
    total_num += x.shape[0]
    total_correct += int(correct)
    
acc = total_correct / total_num
print(epoch,'acc:',acc)

十、附錄

10.1、注意事項

  • 資料集的形狀十分重要,無論是加載后資料集還是要預處理的資料集,都應確保其 s h a p e shape shape 準確,否則無法代入網路進行訓練;
  • 利用 tf.one_hot() 函式進行編碼時,要確保資料為 tf.int64

10.2、完整代碼

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # 屏蔽tensorflow的輸出日志資訊,必須放在匯入TF庫前


import tensorflow as tf
from tensorflow.keras import layers,Sequential,optimizers,datasets

conv_layers = [
    # unit 1
    layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 2
    layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 3
    layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 4
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),

    # unit 5
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same')

]

def preprocess(x,y):  # 預處理函式
    # [0-1]
    x = tf.cast(x,dtype=tf.float32) / 255
    y = tf.cast(y,dtype=tf.int64)         # y必須是整數型,因為onehot里輸入不能為float
    return x,y

(x,y),(x_test,y_test) = datasets.cifar100.load_data() # 加載資料集
y = tf.reshape(y,[50000])
y_test = tf.reshape(y_test,[10000])

train_db = tf.data.Dataset.from_tensor_slices((x,y)) # 將輸入的張量的第一個維度看做樣本的個數,沿其第一個維度將tensor切片,得到的每個切片是一個樣本資料,實作了輸入張量的自動切片,
train_db = train_db.shuffle(10000).map(preprocess).batch(64)
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(64)

def main():
    # [b,32,32,3] => [b,1,1,512]
    conv_net = Sequential(conv_layers)
    # x =tf.random.normal([4,32,32,3])
    # out = conv_net(x)
    # print(out.shape)
    fc_net = Sequential([
        layers.Dense(256,activation=tf.nn.relu),
        layers.Dense(128,activation=tf.nn.relu),
        layers.Dense(100,activation=None)             # 最后一個全連接層,因為有100個種類,所以輸出為100
    ])
    conv_net.build(input_shape=[None, 32, 32, 3])
    fc_net.build(input_shape=[None,512])

    optimizer = optimizers.Adam(lr=1e-4)
    variables = conv_net.variables + fc_net.variables

    for epoch in range(50):
        for step,(x,y) in enumerate(train_db):        # enumerate() 函式用于將一個可遍歷的資料物件(如串列、元組或字串)組合為一個索引序列,同時列出資料和資料下標,一般用在 for 回圈當中
            with tf.GradientTape() as tape:
                # [b,32,32,3] => [b,1,1,512]
                out = conv_net(x)
                # flaten
                out = tf.reshape(out,[-1,512])
                # [b,512] => [b,100]
                logits = fc_net(out)
                # onehot編碼:[b] => [b,100]
                y_onehotcode = tf.one_hot(y,100) # y不能為float
                # compute loss
                loss = tf.losses.categorical_crossentropy(y_onehotcode,logits,from_logits=True)
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss,variables)           # 求解梯度
            optimizer.apply_gradients(zip(grads,variables)) # 梯度更新 zip() 函式用于將可迭代的物件作為引數,將物件中對應的元素打包成一個個元組,然后回傳由這些元組組成的串列,
                                                            # 如果各個迭代器的元素個數不一致,則回傳串列長度與最短的物件相同,

            if step % 100 == 0:
                print(epoch,step,'loss:',float(loss))


        total_num = 0
        total_correct = 0

        for x,y in test_db:
            out = conv_net(x)
            out = tf.reshape(out,[-1,512])
            logits = fc_net(out)
            prob = tf.nn.softmax(logits,axis=1) # 進行預測
            pred = tf.argmax(prob,axis=1)
            correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)
            total_num += x.shape[0]
            total_correct += int(correct)

        acc = total_correct / total_num
        print(epoch,'acc:',acc)

if __name__ == '__main__':
    main()

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/250169.html

標籤:AI

上一篇:【python】在學習用于圖和網路分析的python時遇到的問題和解決方法

下一篇:MySQL常用陳述句

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more