主頁 > 移動端開發 > PyTorch_簡單神經網路搭建_MNIST資料集

PyTorch_簡單神經網路搭建_MNIST資料集

2020-10-27 03:32:02 移動端開發

今天用PyTorch參考《Python深度學習基于PyTorch》搭建了一個簡單的神經網路,在這里做一下筆記,

首先附上PyTorch中文檔案鏈接,下面的各介面函式在這里面基本都能查到,寶藏檔案,對于像我一樣的新手菜鳥特別友好,強推!!!
PyTorch中文檔案鏈接

正文開始:
這是本次搭建神經網路的結構圖

此網路包含兩個隱藏層,激活函式都為relu函式,最后用torch.max(out,1)找出張量out最大值索引作為預測值,

下面不廢話了,直接代碼實作

1. 先匯入必要的模塊

import numpy as np
import torch

#匯入PyTorch內置的mnist資料
from torchvision.datasets import mnist

#匯入預處理模塊
from torchvision import transforms
from torch.utils.data import DataLoader

#匯入神經網路工具
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader:
該介面主要用來將自定義的資料讀取介面的輸出或者PyTorch已有的資料讀取介面的輸入按照batch size封裝成Tensor,

2. 定義超引數

#定義后面要用到的超引數
train_batch_size = 64
test_batch_size = 128

#學習率與訓練次數
learning_rate = 0.01
nums_epoches = 20

#優化器的時候使用的引數
lr = 0.1
momentum = 0.5

batch_size:相當于每次匯入訓練的樣本量大小(相比較于一次匯入完,一次匯入一張,需要設定一個合適的量)一般高級演算法要注意設定量,簡單神經網路不用太過在意,

3.下載資料并對資料進行預處理

#用compose來定意預處理函式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])

#下載資料,在工程檔案夾里新建一個data檔案夾儲存下載的資料
train_dataset = mnist.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform, target_transform=None, download=False)

#資料加載器,組合資料集和采樣器,并在資料集上提供單行程或多行程迭代器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

把預處理需要用到的東西組合在Compose里面,
transforms.ToTensor()是把一個取值范圍是[0,255]的PIL.Image或者shape為(H,W,C)的numpy.ndarray,轉換成形狀為[C,H,W],取值范圍是[0,1.0]的torch.FloadTensor,(這句話我的理解是把資料格式轉換成網路里可以使用的資料格式)
transforms.Normalize則是將灰度影像正則化,

4.可視化資料

import matplotlib.pyplot as plt
%matplotlib inline
examples = enumerate(test_loader)
batch_idx,(example_data,example_targets) = next(examples)
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0],cmap='gray',interpolation='none')
    plt.title("Ground Truth:{}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])

資料可視化這部分我沒有仔細看,直接把代碼列出,以后如果之后再用到的時候我再單獨寫一個筆記,
下面是可視化后的結果:
這是可視化后的mnist資料

5.構建模型

class CNN(nn.Module):
    def __init__(self,in_dim,hidden_1,hidden_2,out_dim):
        super(CNN,self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, hidden_1, bias=True),nn.BatchNorm1d(hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(hidden_1,hidden_2,bias=True),nn.BatchNorm1d(hidden_2))
        self.layer3 = nn.Sequential(nn.Linear(hidden_2,out_dim))
        
    def forward(self,x):
    	#注意 F 與 nn 下的激活函式使用起來不一樣的
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x

class torch.nn.Sequential(* args):一個時序容器,Modules 會以他們傳入的順序被添加到容器中,當然,也可以傳入一個,

**注:**我敲的時候還不知道CNN 跟簡單神經網路的區別,所以把這個類名定義為CNN了,大家在實作的時候可以定義為Net,

6.實體化網路

#實體化網路,只考慮使用CPU
model = CNN(28*28,300,100,10)
#定義損失函式和優化器
criterion = nn.CrossEntropyLoss()
#momentum:動量因子有什么用處?
optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum)

class torch.nn.CrossEntropyLoss(weight=None, size_average=True):此標準將LogSoftMax和NLLLoss集成到一個類中,當訓練一個多類分類器的時候,這個方法是十分有用的,
動量因子的作用后面會了我再來修改!

7.訓練模型

#開始訓練 先定義存盤損失函式和準確率的陣列
losses = []
acces = []
#測驗用
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    #每次訓練先清零
    train_loss = 0
    train_acc = 0
    #將模型設定為訓練模式
    model.train()
    #動態學習率
    if epoch%5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for img,label in train_loader:
        #例如 img=[64,1,28,28] 做完view()后變為[64,1*28*28]
        #把圖片資料格式轉換成與網路匹配的格式
        img = img.view(img.size(0),-1)
        #前向傳播,將圖片資料傳入模型中
        out = model(img)
        loss = criterion(out,label)
        #反向傳播
        #optimizer.zero_grad()意思是把梯度置零,也就是把loss關于weight的導數變成0
        optimizer.zero_grad()
        loss.backward()
        #這個方法會更新所有的引數,一旦梯度被如backward()之類的函式計算好后,我們就可以呼叫這個函式
        optimizer.step()
        
        #記錄誤差 
        train_loss += loss.item()
        
        #計算分類的準確率,找到概率最大的下標
        _,pred = out.max(1)
        num_correct = (pred == label).sum().item()#記錄標簽正確的個數
        acc = num_correct/img.shape[0]
        train_acc += acc
    losses.append(train_loss/len(train_loader))
    acces.append(train_acc/len(train_loader))
    
    eval_loss = 0
    eval_acc = 0
    model.eval()
    for img,label in test_loader:
        img = img.view(img.size(0),-1)
        
        out = model(img)
        loss = criterion(out,label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        eval_loss += loss.item()
        
        _,pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct/img.shape[0]
        eval_acc += acc
    eval_losses.append(eval_loss/len(test_loader))
    eval_acces.append(eval_acc/len(test_loader))
    

    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'
             .format(epoch,train_loss/len(train_loader),train_acc/len(train_loader),
                    eval_loss/len(test_loader),eval_acc/len(test_loader)))

這是訓練后的輸出:
epoch:0,Train Loss:0.3494,Train Acc:0.9190,Test Loss:0.1510,Test Acc:0.9550
epoch:1,Train Loss:0.1290,Train Acc:0.9644,Test Loss:0.1037,Test Acc:0.9687
epoch:2,Train Loss:0.0882,Train Acc:0.9756,Test Loss:0.0848,Test Acc:0.9744
epoch:3,Train Loss:0.0676,Train Acc:0.9818,Test Loss:0.0686,Test Acc:0.9778
epoch:4,Train Loss:0.0535,Train Acc:0.9853,Test Loss:0.0569,Test Acc:0.9824
epoch:5,Train Loss:0.0385,Train Acc:0.9906,Test Loss:0.0308,Test Acc:0.9906
epoch:6,Train Loss:0.0345,Train Acc:0.9920,Test Loss:0.0306,Test Acc:0.9911
epoch:7,Train Loss:0.0321,Train Acc:0.9930,Test Loss:0.0301,Test Acc:0.9916
epoch:8,Train Loss:0.0324,Train Acc:0.9931,Test Loss:0.0293,Test Acc:0.9919
epoch:9,Train Loss:0.0304,Train Acc:0.9937,Test Loss:0.0288,Test Acc:0.9921
epoch:10,Train Loss:0.0302,Train Acc:0.9935,Test Loss:0.0282,Test Acc:0.9925
epoch:11,Train Loss:0.0294,Train Acc:0.9937,Test Loss:0.0274,Test Acc:0.9929
epoch:12,Train Loss:0.0289,Train Acc:0.9938,Test Loss:0.0274,Test Acc:0.9931
epoch:13,Train Loss:0.0294,Train Acc:0.9941,Test Loss:0.0274,Test Acc:0.9930
epoch:14,Train Loss:0.0286,Train Acc:0.9944,Test Loss:0.0280,Test Acc:0.9925
epoch:15,Train Loss:0.0289,Train Acc:0.9939,Test Loss:0.0279,Test Acc:0.9924
epoch:16,Train Loss:0.0287,Train Acc:0.9939,Test Loss:0.0277,Test Acc:0.9925
epoch:17,Train Loss:0.0290,Train Acc:0.9937,Test Loss:0.0272,Test Acc:0.9929
epoch:18,Train Loss:0.0295,Train Acc:0.9938,Test Loss:0.0277,Test Acc:0.9924
epoch:19,Train Loss:0.0285,Train Acc:0.9942,Test Loss:0.0275,Test Acc:0.9932

8.可視化訓練及測驗損失值

plt.title('trainloss')
plt.plot(np.arange(len(losses)),losses)
plt.legend(['Train Loss'],loc='upper right')

可視化訓練次數于損失函式值的關系:
在這里插入圖片描述
end
第一次在CSDN上寫筆記,希望可以堅持,慢慢成長,

轉載請註明出處,本文鏈接:https://www.uj5u.com/yidong/192995.html

標籤:其他

上一篇:Python實作愷撒密碼(8行代碼)

下一篇:6000字長文,帶你用Python完成 “Excel合并(拆分)” 的各種操作!

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 【從零開始擼一個App】Dagger2

    Dagger2是一個IOC框架,一般用于Android平臺,第一次接觸的朋友,一定會被搞得暈頭轉向。它延續了Java平臺Spring框架代碼碎片化,注解滿天飛的傳統。嘗試將各處代碼片段串聯起來,理清思緒,真不是件容易的事。更不用說還有各版本細微的差別。 與Spring不同的是,Spring是通過反射 ......

    uj5u.com 2020-09-10 06:57:59 more
  • Flutter Weekly Issue 66

    新聞 Flutter 季度調研結果分享 教程 Flutter+FaaS一體化任務編排的思考與設計 詳解Dart中如何通過注解生成代碼 GitHub 用對了嗎?Flutter 團隊分享如何管理大型開源專案 插件 flutter-bubble-tab-indicator A Flutter librar ......

    uj5u.com 2020-09-10 06:58:52 more
  • Proguard 常用規則

    介紹 Proguard 入口,如何查看輸出,如何使用 keep 設定入口以及使用實體,如何配置壓縮,混淆,校驗等規則。

    ......

    uj5u.com 2020-09-10 06:59:00 more
  • Android 開發技術周報 Issue#292

    新聞 Android即將獲得類AirDrop功能:可向附近設備快速分享檔案 谷歌為安卓檔案管理應用引入可安全隱藏資料的Safe Folder功能 Android TV新主界面將顯示電影、電視節目和應用推薦內容 泄露的Android檔案暗示了傳說中的谷歌Pixel 5a與折疊屏新機 谷歌發布Andro ......

    uj5u.com 2020-09-10 07:00:37 more
  • AutoFitTextureView Error inflating class

    報錯: Binary XML file line #0: Binary XML file line #0: Error inflating class xxx.AutoFitTextureView 解決: <com.example.testy2.AutoFitTextureView android: ......

    uj5u.com 2020-09-10 07:00:41 more
  • 根據Uri,Cursor沒有獲取到對應的屬性

    Android: 背景:呼叫攝像頭,拍攝視頻,指定保存的地址,但是回傳的Cursor檔案,只有名稱和大小的屬性,沒有其他諸如時長,連ID屬性都沒有 使用 cursor.getInt(cursor.getColumnIndexOrThrow(MediaStore.Video.Media.DURATIO ......

    uj5u.com 2020-09-10 07:00:44 more
  • Android連載29-持久化技術

    一、持久化技術 我們平時所使用的APP產生的資料,在記憶體中都是瞬時的,會隨著斷電、關機等丟失資料,因此android系統采用了持久化技術,用于存盤這些“瞬時”資料 持久化技術包括:檔案存盤、SharedPreference存盤以及資料庫存盤,還有更復雜的SD卡記憶體儲。 二、檔案存盤 最基本存盤方式, ......

    uj5u.com 2020-09-10 07:00:47 more
  • Android Camera2Video整合到自己專案里

    背景: Android專案里呼叫攝像頭拍攝視頻,原本使用的 MediaStore.ACTION_VIDEO_CAPTURE, 后來因專案需要,改成了camera2 1.Camera2Video 官方demo有點問題,下載后,不能直接整合到專案 問題1.多次拍攝視頻崩潰 問題2.雙擊record按鈕, ......

    uj5u.com 2020-09-10 07:00:50 more
  • Android 開發技術周報 Issue#293

    新聞 谷歌為Android TV開發者提供多種新功能 Android 11將自動填表功能整合到鍵盤輸入建議中 谷歌宣布Android Auto即將支持更多的導航和數字停車應用 谷歌Pixel 5只有XL版本 搭載驍龍765G且將比Pixel 4更便宜 [圖]Wear OS將迎來重磅更新:應用啟動時間 ......

    uj5u.com 2020-09-10 07:01:38 more
  • 海豚星空掃碼投屏 Android 接收端 SDK 集成 六步驟

    掃碼投屏,開放網路,獨占設備,不需要額外下載軟體,微信掃碼,發現設備。支持標準DLNA協議,支持倍速播放。視頻,音頻,圖片投屏。好點意思。還支持自定義基于 DLNA 擴展的操作動作。好像要收費,沒體驗。 這里簡單記錄一下集成程序。 一 跟目錄的build.gradle添加私有mevan倉庫 mave ......

    uj5u.com 2020-09-10 07:01:43 more
最新发布
  • 歡迎頁輪播影片

    如圖,引導開始,球從上落下,同時淡入文字,然后文字開始輪播,最后一頁時停止,點擊進入首頁。 在來看看效果圖。 重力球先不講,主要歡迎輪播簡單實作 首先新建一個類 TextTranslationXGuideView,用于影片展示 文本是類似的,最后會有個圖片箭頭影片,布局很簡單,就是一個 TextVi ......

    uj5u.com 2023-04-20 08:40:31 more
  • 【FAQ】關于華為推送服務因營銷訊息頻次管控導致服務通訊類訊息

    一. 問題描述 使用華為推送服務下發IM訊息時,下發訊息請求成功且code碼為80000000,但是手機總是收不到訊息; 在華為推送自助分析(Beta)平臺查看發現,訊息發送觸發了頻控。 二. 問題原因及背景 2023年1月05日起,華為推送服務對咨詢營銷類訊息做了單個設備每日推送數量上限管理,具體 ......

    uj5u.com 2023-04-20 08:40:11 more
  • 歡迎頁輪播影片

    如圖,引導開始,球從上落下,同時淡入文字,然后文字開始輪播,最后一頁時停止,點擊進入首頁。 在來看看效果圖。 重力球先不講,主要歡迎輪播簡單實作 首先新建一個類 TextTranslationXGuideView,用于影片展示 文本是類似的,最后會有個圖片箭頭影片,布局很簡單,就是一個 TextVi ......

    uj5u.com 2023-04-20 08:39:36 more
  • 【FAQ】關于華為推送服務因營銷訊息頻次管控導致服務通訊類訊息

    一. 問題描述 使用華為推送服務下發IM訊息時,下發訊息請求成功且code碼為80000000,但是手機總是收不到訊息; 在華為推送自助分析(Beta)平臺查看發現,訊息發送觸發了頻控。 二. 問題原因及背景 2023年1月05日起,華為推送服務對咨詢營銷類訊息做了單個設備每日推送數量上限管理,具體 ......

    uj5u.com 2023-04-20 08:39:13 more
  • iOS從UI記憶體地址到讀取成員變數(oc/swift)

    開發除錯時,我們發現bug時常首先是從UI顯示發現例外,下一步才會去定位UI相關連的資料的。XCode有給我們提供一系列debug工具,但是很多人可能還沒有形成一套穩定的除錯流程,因此本文嘗試解決這個問題,順便提出一個暴論:UI顯示例外問題只需要兩個步驟就能完成定位作業的80%: 定位例外 UI 組 ......

    uj5u.com 2023-04-19 09:16:23 more
  • FIDE重磅更新!性能飛躍!體驗有禮!

    FIDE 開發者工具重構升級啦!實作500%性能提升,誠邀體驗! 一直以來不少開發者朋友在社區反饋,在使用 FIDE 工具的程序中,時常會遇到諸如加載不及時、代碼預覽/渲染性能不如意的情況,十分影響開發體驗。 作為技術團隊,我們深知一件趁手的開發工具對開發者的重要性,因此,在2023年開年,FinC ......

    uj5u.com 2023-04-19 09:16:15 more
  • 游戲內嵌社區服務開放,助力開發者提升玩家互動與留存

    華為 HMS Core 游戲內嵌社區服務提供快速訪問華為游戲中心論壇能力,支持玩家直接在游戲內瀏覽帖子和交流互動,助力開發者擴展內容生產和觸達的場景。 一、為什么要游戲內嵌社區? 二、游戲內嵌社區的典型使用場景 1、游戲內打開論壇 您可以在游戲內繪制論壇入口,為玩家提供沉浸式發帖、瀏覽、點贊、回帖、 ......

    uj5u.com 2023-04-19 09:15:46 more
  • iOS從UI記憶體地址到讀取成員變數(oc/swift)

    開發除錯時,我們發現bug時常首先是從UI顯示發現例外,下一步才會去定位UI相關連的資料的。XCode有給我們提供一系列debug工具,但是很多人可能還沒有形成一套穩定的除錯流程,因此本文嘗試解決這個問題,順便提出一個暴論:UI顯示例外問題只需要兩個步驟就能完成定位作業的80%: 定位例外 UI 組 ......

    uj5u.com 2023-04-19 09:14:53 more
  • FIDE重磅更新!性能飛躍!體驗有禮!

    FIDE 開發者工具重構升級啦!實作500%性能提升,誠邀體驗! 一直以來不少開發者朋友在社區反饋,在使用 FIDE 工具的程序中,時常會遇到諸如加載不及時、代碼預覽/渲染性能不如意的情況,十分影響開發體驗。 作為技術團隊,我們深知一件趁手的開發工具對開發者的重要性,因此,在2023年開年,FinC ......

    uj5u.com 2023-04-19 09:14:08 more
  • 游戲內嵌社區服務開放,助力開發者提升玩家互動與留存

    華為 HMS Core 游戲內嵌社區服務提供快速訪問華為游戲中心論壇能力,支持玩家直接在游戲內瀏覽帖子和交流互動,助力開發者擴展內容生產和觸達的場景。 一、為什么要游戲內嵌社區? 二、游戲內嵌社區的典型使用場景 1、游戲內打開論壇 您可以在游戲內繪制論壇入口,為玩家提供沉浸式發帖、瀏覽、點贊、回帖、 ......

    uj5u.com 2023-04-19 09:08:34 more