主頁 > 區塊鏈 > PyTorch_簡單神經網路搭建_MNIST資料集

PyTorch_簡單神經網路搭建_MNIST資料集

2020-10-26 23:30:23 區塊鏈

今天用PyTorch參考《Python深度學習基于PyTorch》搭建了一個簡單的神經網路,在這里做一下筆記,

首先附上PyTorch中文檔案鏈接,下面的各介面函式在這里面基本都能查到,寶藏檔案,對于像我一樣的新手菜鳥特別友好,強推!!!
PyTorch中文檔案鏈接

正文開始:
這是本次搭建神經網路的結構圖

此網路包含兩個隱藏層,激活函式都為relu函式,最后用torch.max(out,1)找出張量out最大值索引作為預測值,

下面不廢話了,直接代碼實作

1. 先匯入必要的模塊

import numpy as np
import torch

#匯入PyTorch內置的mnist資料
from torchvision.datasets import mnist

#匯入預處理模塊
from torchvision import transforms
from torch.utils.data import DataLoader

#匯入神經網路工具
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader:
該介面主要用來將自定義的資料讀取介面的輸出或者PyTorch已有的資料讀取介面的輸入按照batch size封裝成Tensor,

2. 定義超引數

#定義后面要用到的超引數
train_batch_size = 64
test_batch_size = 128

#學習率與訓練次數
learning_rate = 0.01
nums_epoches = 20

#優化器的時候使用的引數
lr = 0.1
momentum = 0.5

batch_size:相當于每次匯入訓練的樣本量大小(相比較于一次匯入完,一次匯入一張,需要設定一個合適的量)一般高級演算法要注意設定量,簡單神經網路不用太過在意,

3.下載資料并對資料進行預處理

#用compose來定意預處理函式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])

#下載資料,在工程檔案夾里新建一個data檔案夾儲存下載的資料
train_dataset = mnist.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform, target_transform=None, download=False)

#資料加載器,組合資料集和采樣器,并在資料集上提供單行程或多行程迭代器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

把預處理需要用到的東西組合在Compose里面,
transforms.ToTensor()是把一個取值范圍是[0,255]的PIL.Image或者shape為(H,W,C)的numpy.ndarray,轉換成形狀為[C,H,W],取值范圍是[0,1.0]的torch.FloadTensor,(這句話我的理解是把資料格式轉換成網路里可以使用的資料格式)
transforms.Normalize則是將灰度影像正則化,

4.可視化資料

import matplotlib.pyplot as plt
%matplotlib inline
examples = enumerate(test_loader)
batch_idx,(example_data,example_targets) = next(examples)
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0],cmap='gray',interpolation='none')
    plt.title("Ground Truth:{}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])

資料可視化這部分我沒有仔細看,直接把代碼列出,以后如果之后再用到的時候我再單獨寫一個筆記,
下面是可視化后的結果:
這是可視化后的mnist資料

5.構建模型

class CNN(nn.Module):
    def __init__(self,in_dim,hidden_1,hidden_2,out_dim):
        super(CNN,self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, hidden_1, bias=True),nn.BatchNorm1d(hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(hidden_1,hidden_2,bias=True),nn.BatchNorm1d(hidden_2))
        self.layer3 = nn.Sequential(nn.Linear(hidden_2,out_dim))
        
    def forward(self,x):
    	#注意 F 與 nn 下的激活函式使用起來不一樣的
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x

class torch.nn.Sequential(* args):一個時序容器,Modules 會以他們傳入的順序被添加到容器中,當然,也可以傳入一個,

**注:**我敲的時候還不知道CNN 跟簡單神經網路的區別,所以把這個類名定義為CNN了,大家在實作的時候可以定義為Net,

6.實體化網路

#實體化網路,只考慮使用CPU
model = CNN(28*28,300,100,10)
#定義損失函式和優化器
criterion = nn.CrossEntropyLoss()
#momentum:動量因子有什么用處?
optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum)

class torch.nn.CrossEntropyLoss(weight=None, size_average=True):此標準將LogSoftMax和NLLLoss集成到一個類中,當訓練一個多類分類器的時候,這個方法是十分有用的,
動量因子的作用后面會了我再來修改!

7.訓練模型

#開始訓練 先定義存盤損失函式和準確率的陣列
losses = []
acces = []
#測驗用
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    #每次訓練先清零
    train_loss = 0
    train_acc = 0
    #將模型設定為訓練模式
    model.train()
    #動態學習率
    if epoch%5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for img,label in train_loader:
        #例如 img=[64,1,28,28] 做完view()后變為[64,1*28*28]
        #把圖片資料格式轉換成與網路匹配的格式
        img = img.view(img.size(0),-1)
        #前向傳播,將圖片資料傳入模型中
        out = model(img)
        loss = criterion(out,label)
        #反向傳播
        #optimizer.zero_grad()意思是把梯度置零,也就是把loss關于weight的導數變成0
        optimizer.zero_grad()
        loss.backward()
        #這個方法會更新所有的引數,一旦梯度被如backward()之類的函式計算好后,我們就可以呼叫這個函式
        optimizer.step()
        
        #記錄誤差 
        train_loss += loss.item()
        
        #計算分類的準確率,找到概率最大的下標
        _,pred = out.max(1)
        num_correct = (pred == label).sum().item()#記錄標簽正確的個數
        acc = num_correct/img.shape[0]
        train_acc += acc
    losses.append(train_loss/len(train_loader))
    acces.append(train_acc/len(train_loader))
    
    eval_loss = 0
    eval_acc = 0
    model.eval()
    for img,label in test_loader:
        img = img.view(img.size(0),-1)
        
        out = model(img)
        loss = criterion(out,label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        eval_loss += loss.item()
        
        _,pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct/img.shape[0]
        eval_acc += acc
    eval_losses.append(eval_loss/len(test_loader))
    eval_acces.append(eval_acc/len(test_loader))
    

    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'
             .format(epoch,train_loss/len(train_loader),train_acc/len(train_loader),
                    eval_loss/len(test_loader),eval_acc/len(test_loader)))

這是訓練后的輸出:
epoch:0,Train Loss:0.3494,Train Acc:0.9190,Test Loss:0.1510,Test Acc:0.9550
epoch:1,Train Loss:0.1290,Train Acc:0.9644,Test Loss:0.1037,Test Acc:0.9687
epoch:2,Train Loss:0.0882,Train Acc:0.9756,Test Loss:0.0848,Test Acc:0.9744
epoch:3,Train Loss:0.0676,Train Acc:0.9818,Test Loss:0.0686,Test Acc:0.9778
epoch:4,Train Loss:0.0535,Train Acc:0.9853,Test Loss:0.0569,Test Acc:0.9824
epoch:5,Train Loss:0.0385,Train Acc:0.9906,Test Loss:0.0308,Test Acc:0.9906
epoch:6,Train Loss:0.0345,Train Acc:0.9920,Test Loss:0.0306,Test Acc:0.9911
epoch:7,Train Loss:0.0321,Train Acc:0.9930,Test Loss:0.0301,Test Acc:0.9916
epoch:8,Train Loss:0.0324,Train Acc:0.9931,Test Loss:0.0293,Test Acc:0.9919
epoch:9,Train Loss:0.0304,Train Acc:0.9937,Test Loss:0.0288,Test Acc:0.9921
epoch:10,Train Loss:0.0302,Train Acc:0.9935,Test Loss:0.0282,Test Acc:0.9925
epoch:11,Train Loss:0.0294,Train Acc:0.9937,Test Loss:0.0274,Test Acc:0.9929
epoch:12,Train Loss:0.0289,Train Acc:0.9938,Test Loss:0.0274,Test Acc:0.9931
epoch:13,Train Loss:0.0294,Train Acc:0.9941,Test Loss:0.0274,Test Acc:0.9930
epoch:14,Train Loss:0.0286,Train Acc:0.9944,Test Loss:0.0280,Test Acc:0.9925
epoch:15,Train Loss:0.0289,Train Acc:0.9939,Test Loss:0.0279,Test Acc:0.9924
epoch:16,Train Loss:0.0287,Train Acc:0.9939,Test Loss:0.0277,Test Acc:0.9925
epoch:17,Train Loss:0.0290,Train Acc:0.9937,Test Loss:0.0272,Test Acc:0.9929
epoch:18,Train Loss:0.0295,Train Acc:0.9938,Test Loss:0.0277,Test Acc:0.9924
epoch:19,Train Loss:0.0285,Train Acc:0.9942,Test Loss:0.0275,Test Acc:0.9932

8.可視化訓練及測驗損失值

plt.title('trainloss')
plt.plot(np.arange(len(losses)),losses)
plt.legend(['Train Loss'],loc='upper right')

可視化訓練次數于損失函式值的關系:
在這里插入圖片描述
end
第一次在CSDN上寫筆記,希望可以堅持,慢慢成長,

轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/192729.html

標籤:區塊鏈

上一篇:無人帆船模擬及實船實驗步驟

下一篇:OpenCV:使用python-cv2實作Harr+Adaboost人臉識別

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • JAVA使用 web3j 進行token轉賬

    最近新學習了下區塊鏈這方面的知識,所學不多,給大家分享下。 # 1. 關于web3j web3j是一個高度模塊化,反應性,型別安全的Java和Android庫,用于與智能合約配合并與以太坊網路上的客戶端(節點)集成。 # 2. 準備作業 jdk版本1.8 引入maven <dependency> < ......

    uj5u.com 2020-09-10 03:03:06 more
  • 以太坊智能合約開發框架Truffle

    前言 部署智能合約有多種方式,命令列的瀏覽器的渠道都有,但往往跟我們程式員的風格不太相符,因為我們習慣了在IDE里寫了代碼然后打包運行看效果。 雖然現在IDE中已經存在了Solidity插件,可以撰寫智能合約,但是部署智能合約卻要另走他路,沒辦法進行一個快捷的部署與測驗。 如果團隊管理的區塊節點多、 ......

    uj5u.com 2020-09-10 03:03:12 more
  • 谷歌二次驗證碼成為區塊鏈專用安全碼,你怎么看?

    前言 谷歌身份驗證器,前些年大家都比較陌生,但隨著國內互聯網安全的加強,它越來越多地出現在大家的視野中。 比較廣泛接觸的人群是國際3A游戲愛好者,游戲盜號現象嚴重+國外賬號安全應用廣泛,這類游戲一般都會要求用戶系結名為“兩步驗證”、“雙重驗證”等,平臺一般都推薦用谷歌身份驗證器。 后來區塊鏈業務風靡 ......

    uj5u.com 2020-09-10 03:03:17 more
  • 密碼學DAY1

    目錄 ##1.1 密碼學基本概念 密碼在我們的生活中有著重要的作用,那么密碼究竟來自何方,為何會產生呢? 密碼學是網路安全、資訊安全、區塊鏈等產品的基礎,常見的非對稱加密、對稱加密、散列函式等,都屬于密碼學范疇。 密碼學有數千年的歷史,從最開始的替換法到如今的非對稱加密演算法,經歷了古典密碼學,近代密 ......

    uj5u.com 2020-09-10 03:03:50 more
  • 密碼學DAY1_02

    目錄 ##1.1 ASCII編碼 ASCII(American Standard Code for Information Interchange,美國資訊交換標準代碼)是基于拉丁字母的一套電腦編碼系統,主要用于顯示現代英語和其他西歐語言。它是現今最通用的單位元組編碼系統,并等同于國際標準ISO/IE ......

    uj5u.com 2020-09-10 03:04:50 more
  • 密碼學DAY2

    ##1.1 加密模式 加密模式:https://docs.oracle.com/javase/8/docs/api/javax/crypto/Cipher.html ECB ECB : Electronic codebook, 電子密碼本. 需要加密的訊息按照塊密碼的塊大小被分為數個塊,并對每個塊進 ......

    uj5u.com 2020-09-10 03:05:42 more
  • NTP時鐘服務器的特點(京準電子)

    NTP時鐘服務器的特點(京準電子) NTP時鐘服務器的特點(京準電子) 京準電子官V——ahjzsz 首先對時間同步進行了背景介紹,然后討論了不同的時間同步網路技術,最后指出了建立全球或區域時間同步網存在的問題。 一、概 述 在通信領域,“同步”概念是指頻率的同步,即網路各個節點的時鐘頻率和相位同步 ......

    uj5u.com 2020-09-10 03:05:47 more
  • 標準化考場時鐘同步系統推進智能化校園建設

    標準化考場時鐘同步系統推進智能化校園建設 標準化考場時鐘同步系統推進智能化校園建設 安徽京準電子科技官微——ahjzsz 一、背景概述隨著教育事業的快速發展,學校建設如雨后春筍,隨之而來的學校教育、管理、安全方面的問題成了學校管理人員面臨的最大的挑戰,這些問題同時也是學生家長所擔心的。為了讓學生有更 ......

    uj5u.com 2020-09-10 03:05:51 more
  • 位元幣入門

    引言 位元幣基本結構 位元幣基礎知識 1)哈希演算法 2)非對稱加密技術 3)數字簽名 4)MerkleTree 5)哪有位元幣,有的是UTXO 6)位元幣挖礦與共識 7)區塊驗證(共識) 總結 引言 上一篇我們已經知道了什么是區塊鏈,此篇說一下區塊鏈的第一個應用——位元幣。其實先有位元幣,后有的區塊 ......

    uj5u.com 2020-09-10 03:06:15 more
  • 北斗對時服務器(北斗對時設備)電力系統應用

    北斗對時服務器(北斗對時設備)電力系統應用 北斗對時服務器(北斗對時設備)電力系統應用 京準電子科技官微(ahjzsz) 中國北斗衛星導航系統(英文名稱:BeiDou Navigation Satellite System,簡稱BDS),因為是目前世界范圍內唯一可以大面積提供免費定位服務的系統,所以 ......

    uj5u.com 2020-09-10 03:06:20 more
最新发布
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:46:47 more
  • Hyperledger Fabric 使用 CouchDB 和復雜智能合約開發

    在上個實驗中,我們已經實作了簡單智能合約實作及客戶端開發,但該實驗中智能合約只有基礎的增刪改查功能,且其中的資料管理功能與傳統 MySQL 比相差甚遠。本文將在前面實驗的基礎上,將 Hyperledger Fabric 的默認資料庫支持 LevelDB 改為 CouchDB 模式,以實作更復雜的資料... ......

    uj5u.com 2023-04-16 07:28:31 more
  • .NET Core 波場鏈離線簽名、廣播交易(發送 TRX和USDT)筆記

    Get Started NuGet You can run the following command to install the Tron.Wallet.Net in your project. PM> Install-Package Tron.Wallet.Net 配置 public reco ......

    uj5u.com 2023-04-14 08:08:00 more
  • DKP 黑客分析——不正確的代幣對比率計算

    概述: 2023 年 2 月 8 日,針對 DKP 協議的閃電貸攻擊導致該協議的用戶損失了 8 萬美元,因為 execute() 函式取決于 USDT-DKP 對中兩種代幣的余額比率。 智能合約黑客概述: 攻擊者的交易:0x0c850f,0x2d31 攻擊者地址:0xF38 利用合同:0xf34ad ......

    uj5u.com 2023-04-07 07:46:09 more
  • Defi開發簡介

    Defi開發簡介 介紹 Defi是去中心化金融的縮寫, 是一項旨在利用區塊鏈技術和智能合約創建更加開放,可訪問和透明的金融體系的運動. 這與傳統金融形成鮮明對比,傳統金融通常由少數大型銀行和金融機構控制 在Defi的世界里,用戶可以直接從他們的電腦或移動設備上訪問廣泛的金融服務,而不需要像銀行或者信 ......

    uj5u.com 2023-04-05 08:01:34 more
  • solidity簡單的ERC20代幣實作

    // SPDX-License-Identifier: GPL-3.0 pragma solidity >=0.7.0 <0.9.0; import "hardhat/console.sol"; //ERC20 同質化代幣,每個代幣的本質或性質都是相同 //ETH 是原生代幣,它不是ERC20代幣, ......

    uj5u.com 2023-03-21 07:56:29 more
  • solidity 參考型別修飾符memory、calldata與storage 常量修飾符C

    在solidity語言中 參考型別修飾符(參考型別為存盤空間不固定的數值型別) memory、calldata與storage,它們只能修飾參考型別變數,比如字串、陣列、位元組等... memory 適用于方法傳參、返參或在方法體內使用,使用完就會清除掉,釋放記憶體 calldata 僅適用于方法傳參 ......

    uj5u.com 2023-03-08 07:57:54 more
  • solidity注解標簽

    在solidity語言中 注釋符為// 注解符為/* 內容*/ 或者 是 ///內容 注解中含有這幾個標簽給予我們使用 @title 一個應該描述合約/介面的標題 contract, library, interface @author 作者的名字 contract, library, interf ......

    uj5u.com 2023-03-08 07:57:49 more
  • 評價指標:相似度、GAS消耗

    【代碼注釋自動生成方法綜述】 這些評測指標主要來自機器翻譯和文本總結等研究領域,可以評估候選文本(即基于代碼注釋自動方法而生成)和參考文本(即基于手工方式而生成)的相似度. BLEU指標^[^?88^^?^]^:其全稱是bilingual evaluation understudy.該指標是最早用于 ......

    uj5u.com 2023-02-23 07:27:39 more
  • 基于NOSTR協議的“公有制”版本的Twitter,去中心化社交軟體Damus

    最近,一個幽靈,Web3的幽靈,在網路游蕩,它叫Damus,這玩意詮釋了什么叫做病毒式營銷,滑稽的是,一個Web3產品卻在Web2的產品鏈上瘋狂傳銷,各方大佬紛紛為其背書,到底發生了什么?Damus的葫蘆里,賣的是什么藥? 注冊和簡單實用 很少有什么產品在用戶注冊環節會有什么噱頭,但Damus確實出 ......

    uj5u.com 2023-02-05 06:48:39 more