主頁 > 區塊鏈 > 【原創】python實作BP神經網路識別Mnist資料集

【原創】python實作BP神經網路識別Mnist資料集

2020-11-17 18:55:20 區塊鏈

著作權宣告:本文為博主ExcelMann的原創文章,未經博主允許不得轉載,

python實作BP神經網路識別Mnist資料集

作者:ExcelMann,轉載需注明,

話不多說,直接貼代碼,代碼有注釋,

# Author:Xuangan, Xu
# Data:2020-10-28

"""
BP神經網路
-----------------
利用梯度下降法,實作MNIST手寫體數字識別
資料集:Mnist資料集
"""

import os
import struct
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

def load_mnist(path, kind='train'): # kind默認引數值為'train'
    """
        從指定的路徑path中讀取資料
        :param path:為檔案路徑
        :param kind:為檔案的型別(train/t10k)
        :return images為nxm維的陣列,n為樣本個數,m為樣本的特征數,也即為像素個數;
                labels為images對應的標簽;
    """
    labels_path = os.path.join(path,
                               '%s-labels.idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images.idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

def sigmoid(x):
    """
    sigmoid函式
    :param x: 輸入值
    :return: 回傳激活函式的值
    """
    return 1.0/(1.0+np.exp(-x))

# 定義神經網路類
class neuralNetwork:

    def __init__(self,inputNodes,hiddenNodes,outputNodes,learningRate):
        """
        :param inputNodes:輸入層節點個數
        :param hiddenNodes:隱藏層結點個數
        :param outputNodes:輸出層結點個數
        :param learningRate:學習率
        """
        self.iNodes = inputNodes
        self.hNodes = hiddenNodes
        self.oNodes = outputNodes
        self.lr = learningRate
        # 初始化網路權重
        self.w_1 = np.random.uniform(-0.5,0.5,(inputNodes,hiddenNodes))
        self.w_2 = np.random.uniform(-0.5, 0.5, (hiddenNodes,outputNodes))
        # 初始化閾值
        #self.thod_1 = np.random.randn(hiddenNodes)
        self.thod_1 = np.random.uniform(-0.5,0.5,hiddenNodes)
        #self.thod_2 = np.random.randn(outputNodes)
        self.thod_2 = np.random.uniform(-0.5,0.5,outputNodes)

    def culMse(self,pre_y,y):
        """
        計算均方誤差
        :param pre_y: 預測值
        :param y: 期望值
        """
        totalError = 0
        for i in range(len(y)):
            totalError += (y[i]-pre_y[i])**2
        return totalError/2.0

    def culCrossEntropyLoss(self,pre_y,y):
        """
        計算交叉熵損失函式
        :param pre_y: 預測值
        :param y: 期望值
        """
        total_error = 0
        for j in range(len(y)):
            total_error += y[j]*math.log(pre_y[j])
        return (-1)*total_error

    def forward(self,input_data):
        """
        前向傳播
        :param input_data:輸入資料(1X784的一維陣列)
        :return: 回傳輸出層的資料
        """
        # 計算隱含層的輸入值以及輸出值(用到了sigmoid激活函式),結果為大小15的陣列
        hidden_input = input_data.dot(self.w_1)
        hidden_output = sigmoid(hidden_input-self.thod_1)

        # 計算輸出層的輸入值以及輸出值(用到了sigmoid激活函式),結果為大小10的陣列
        final_input = hidden_output.dot(self.w_2)
        final_output = sigmoid(final_input-self.thod_2)
        return final_output,hidden_output

    def backward(self,target,input_data,hidden_output,final_output):
        """
        反向傳播演算法
        """
        g = np.zeros(self.oNodes)  # 第j個輸出層結點對應的廣義偏差
        e = np.zeros(self.hNodes)  # 第h個隱藏層結點對應的廣義偏差
        # 更新隱藏層與輸出層之間的權重w_2
        for h in range(self.hNodes):
            for j in range(self.oNodes):
                # 計算第j個輸出層結點對應的廣義偏差
                g[j] = (target[j]-final_output[j])*final_output[j]*(1-final_output[j])
                # 計算w_hj的權重梯度
                gradient_w_hj = self.lr*g[j]*hidden_output[h]
                # 梯度下降法更新權重引數值
                self.w_2[h][j] = self.w_2[h][j]+gradient_w_hj

        # 更新輸出層的閾值
        for j in range(self.oNodes):
            # 計算第j個輸出層結點的閾值梯度
            gradient_thod_j = (-1) * self.lr * g[j]
            # 梯度下降法更新閾值引數
            self.thod_2[j] = self.thod_2[j] + gradient_thod_j

        # 求第h個隱藏層結點對應的廣義偏差
        for h in range(self.hNodes):
            totalBackValue = 0
            for j in range(self.oNodes):
                totalBackValue += self.w_2[h][j]*g[j]
            e[h] = hidden_output[h]*(1-hidden_output[h])*totalBackValue

        # 更新輸入層與隱藏層之間的權重w_1
        for i in range(self.iNodes):
            for h in range(self.hNodes):
                # 計算w_ih的權重梯度
                gradient_w_ih = self.lr*e[h]*input_data[i]
                # 梯度下降法更新權重引數值
                self.w_1[i][h] = self.w_1[i][h]+gradient_w_ih

        # 更新隱藏層的閾值
        for h in range(self.hNodes):
            # 計算第h個隱藏層結點的閾值梯度
            gradient_thod_h = (-1)*self.lr*e[h]
            # 梯度下降法更新閾值引數
            self.thod_1[h] += gradient_thod_h

    def train(self,input_data,target):
        """
        訓練網路引數
        :param input_data:輸入資料(1X784的一維陣列)
        :param target:標簽陣列(1X10的一維陣列)
        """
        final_output,hidden_output = self.forward(input_data)

        self.backward(target,input_data,hidden_output,final_output)

        return final_output


    def estimate(self,test_data,test_label):
        """
        預測結果
        :param test_data: 輸入資料,nX784維,n為輸入資料個數
        :param test_label: 測驗資料的標簽值
        :return: 回傳準確率
        """
        correct_num = 0 # 預測正確個數
        for i in range(test_data.shape[0]):
            # 計算得到預測結果,preV為網路模型輸出值
            preV,hiddenV = self.forward(test_data[i])
            pre_y = np.argmax(preV)  # 最大可能性的即為預測的值
            label = np.argmax(test_label[i])
            # 預測結果與標簽值對比,計算準確率
            if(pre_y == label):
                correct_num += 1
        return correct_num/test_data.shape[0]

    def SGD(self,train_data,train_label):
        # 定義迭代次數epochs,并執行訓練程序
        epochs = 200
        # 批處理的量大小
        batch_size = 200
        for e in range(epochs):
            # 從樣本中隨機挑選出100個樣本作為訓練集
            batch_mask = np.random.choice(train_data.shape[0], batch_size)
            batch_data = train_data[batch_mask]
            batch_label = train_label[batch_mask]
            # 遍歷批處理樣本
            for i, data in enumerate(batch_data):
                # 執行模型訓練
                final_output = self.train(data, batch_label[i])
                if (i % 40 == 0):
                    # 計算loss
                    mse = self.culMse(final_output, batch_label[i])
                    print(f'epoch:{e},i:{i},loss:{mse}')

if __name__ == "__main__":
    # 通過tensorflow讀取mnist資料,并對讀取到的資料進行處理
    mnist = tf.keras.datasets.mnist
    (train_x,train_y),(test_x,test_y) = mnist.load_data()
    # 創建以下陣列,用于存盤處理后的訓練和測驗資料
    train_data = np.zeros((60000,784))
    train_label = np.zeros((60000,10))
    test_data = np.zeros((10000,784))
    test_label = np.zeros((10000,10))
    # 處理資料,使得影像資料的值范圍為0-1,并將標簽改為one-hot型別
    for i in range(60000):  # 處理訓練資料
        train_data[i] = (np.array(train_x[i]).flatten())/255
        temp = np.zeros(10)
        temp[train_y[i]] = 1
        train_label[i] = temp
    for i in range(10000):  # 處理測驗資料
        test_data[i] = (np.array(test_x[i]).flatten())/255
        temp = np.zeros(10)
        temp[test_y[i]] = 1
        test_label[i] = temp

    # 初始化神經網路結點個數和學習率
    input_nodes = 784
    hidden_nodes = 15
    output_nodes = 10
    learningRate = 0.15
    # 創建神經網路物件network
    network = neuralNetwork(input_nodes,hidden_nodes,output_nodes,learningRate)
    # 執行隨機梯度下降演算法
    network.SGD(train_data,train_label)
    
    # 測驗階段,輸出精確率
    accuracy = network.estimate(test_data,test_label)
    print(f'test_data_Accuracy:{accuracy}')

轉載請註明出處,本文鏈接:https://www.uj5u.com/qukuanlian/222541.html

標籤:區塊鏈

上一篇:表面模糊濾鏡

下一篇:【MATLAB-app】appdesigner 設計中的幾個神招數(精心打造,附源代碼)

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • JAVA使用 web3j 進行token轉賬

    最近新學習了下區塊鏈這方面的知識,所學不多,給大家分享下。 # 1. 關于web3j web3j是一個高度模塊化,反應性,型別安全的Java和Android庫,用于與智能合約配合并與以太坊網路上的客戶端(節點)集成。 # 2. 準備作業 jdk版本1.8 引入maven <dependency> < ......

    uj5u.com 2020-09-10 03:03:06 more
  • 以太坊智能合約開發框架Truffle

    前言 部署智能合約有多種方式,命令列的瀏覽器的渠道都有,但往往跟我們程式員的風格不太相符,因為我們習慣了在IDE里寫了代碼然后打包運行看效果。 雖然現在IDE中已經存在了Solidity插件,可以撰寫智能合約,但是部署智能合約卻要另走他路,沒辦法進行一個快捷的部署與測驗。 如果團隊管理的區塊節點多、 ......

    uj5u.com 2020-09-10 03:03:12 more
  • 谷歌二次驗證碼成為區塊鏈專用安全碼,你怎么看?

    前言 谷歌身份驗證器,前些年大家都比較陌生,但隨著國內互聯網安全的加強,它越來越多地出現在大家的視野中。 比較廣泛接觸的人群是國際3A游戲愛好者,游戲盜號現象嚴重+國外賬號安全應用廣泛,這類游戲一般都會要求用戶系結名為“兩步驗證”、“雙重驗證”等,平臺一般都推薦用谷歌身份驗證器。 后來區塊鏈業務風靡 ......

    uj5u.com 2020-09-10 03:03:17 more
  • 密碼學DAY1

    目錄 ##1.1 密碼學基本概念 密碼在我們的生活中有著重要的作用,那么密碼究竟來自何方,為何會產生呢? 密碼學是網路安全、資訊安全、區塊鏈等產品的基礎,常見的非對稱加密、對稱加密、散列函式等,都屬于密碼學范疇。 密碼學有數千年的歷史,從最開始的替換法到如今的非對稱加密演算法,經歷了古典密碼學,近代密 ......

    uj5u.com 2020-09-10 03:03:50 more
  • 密碼學DAY1_02

    目錄 ##1.1 ASCII編碼 ASCII(American Standard Code for Information Interchange,美國資訊交換標準代碼)是基于拉丁字母的一套電腦編碼系統,主要用于顯示現代英語和其他西歐語言。它是現今最通用的單位元組編碼系統,并等同于國際標準ISO/IE ......

    uj5u.com 2020-09-10 03:04:50 more
  • 密碼學DAY2

    ##1.1 加密模式 加密模式:https://docs.oracle.com/javase/8/docs/api/javax/crypto/Cipher.html ECB ECB : Electronic codebook, 電子密碼本. 需要加密的訊息按照塊密碼的塊大小被分為數個塊,并對每個塊進 ......

    uj5u.com 2020-09-10 03:05:42 more
  • NTP時鐘服務器的特點(京準電子)

    NTP時鐘服務器的特點(京準電子) NTP時鐘服務器的特點(京準電子) 京準電子官V——ahjzsz 首先對時間同步進行了背景介紹,然后討論了不同的時間同步網路技術,最后指出了建立全球或區域時間同步網存在的問題。 一、概 述 在通信領域,“同步”概念是指頻率的同步,即網路各個節點的時鐘頻率和相位同步 ......

    uj5u.com 2020-09-10 03:05:47 more
  • 標準化考場時鐘同步系統推進智能化校園建設

    標準化考場時鐘同步系統推進智能化校園建設 標準化考場時鐘同步系統推進智能化校園建設 安徽京準電子科技官微——ahjzsz 一、背景概述隨著教育事業的快速發展,學校建設如雨后春筍,隨之而來的學校教育、管理、安全方面的問題成了學校管理人員面臨的最大的挑戰,這些問題同時也是學生家長所擔心的。為了讓學生有更 ......

    uj5u.com 2020-09-10 03:05:51 more
  • 位元幣入門

    引言 位元幣基本結構 位元幣基礎知識 1)哈希演算法 2)非對稱加密技術 3)數字簽名 4)MerkleTree 5)哪有位元幣,有的是UTXO 6)位元幣挖礦與共識 7)區塊驗證(共識) 總結 引言 上一篇我們已經知道了什么是區塊鏈,此篇說一下區塊鏈的第一個應用——位元幣。其實先有位元幣,后有的區塊 ......

    uj5u.com 2020-09-10 03:06:15 more
  • 北斗對時服務器(北斗對時設備)電力系統應用

    北斗對時服務器(北斗對時設備)電力系統應用 北斗對時服務器(北斗對時設備)電力系統應用 京準電子科技官微(ahjzsz) 中國北斗衛星導航系統(英文名稱:BeiDou Navigation Satellite System,簡稱BDS),因為是目前世界范圍內唯一可以大面積提供免費定位服務的系統,所以 ......

    uj5u.com 2020-09-10 03:06:20 more
最新发布
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:46:47 more
  • Hyperledger Fabric 使用 CouchDB 和復雜智能合約開發

    在上個實驗中,我們已經實作了簡單智能合約實作及客戶端開發,但該實驗中智能合約只有基礎的增刪改查功能,且其中的資料管理功能與傳統 MySQL 比相差甚遠。本文將在前面實驗的基礎上,將 Hyperledger Fabric 的默認資料庫支持 LevelDB 改為 CouchDB 模式,以實作更復雜的資料... ......

    uj5u.com 2023-04-16 07:28:31 more
  • .NET Core 波場鏈離線簽名、廣播交易(發送 TRX和USDT)筆記

    Get Started NuGet You can run the following command to install the Tron.Wallet.Net in your project. PM> Install-Package Tron.Wallet.Net 配置 public reco ......

    uj5u.com 2023-04-14 08:08:00 more
  • DKP 黑客分析——不正確的代幣對比率計算

    概述: 2023 年 2 月 8 日,針對 DKP 協議的閃電貸攻擊導致該協議的用戶損失了 8 萬美元,因為 execute() 函式取決于 USDT-DKP 對中兩種代幣的余額比率。 智能合約黑客概述: 攻擊者的交易:0x0c850f,0x2d31 攻擊者地址:0xF38 利用合同:0xf34ad ......

    uj5u.com 2023-04-07 07:46:09 more
  • Defi開發簡介

    Defi開發簡介 介紹 Defi是去中心化金融的縮寫, 是一項旨在利用區塊鏈技術和智能合約創建更加開放,可訪問和透明的金融體系的運動. 這與傳統金融形成鮮明對比,傳統金融通常由少數大型銀行和金融機構控制 在Defi的世界里,用戶可以直接從他們的電腦或移動設備上訪問廣泛的金融服務,而不需要像銀行或者信 ......

    uj5u.com 2023-04-05 08:01:34 more
  • solidity簡單的ERC20代幣實作

    // SPDX-License-Identifier: GPL-3.0 pragma solidity >=0.7.0 <0.9.0; import "hardhat/console.sol"; //ERC20 同質化代幣,每個代幣的本質或性質都是相同 //ETH 是原生代幣,它不是ERC20代幣, ......

    uj5u.com 2023-03-21 07:56:29 more
  • solidity 參考型別修飾符memory、calldata與storage 常量修飾符C

    在solidity語言中 參考型別修飾符(參考型別為存盤空間不固定的數值型別) memory、calldata與storage,它們只能修飾參考型別變數,比如字串、陣列、位元組等... memory 適用于方法傳參、返參或在方法體內使用,使用完就會清除掉,釋放記憶體 calldata 僅適用于方法傳參 ......

    uj5u.com 2023-03-08 07:57:54 more
  • solidity注解標簽

    在solidity語言中 注釋符為// 注解符為/* 內容*/ 或者 是 ///內容 注解中含有這幾個標簽給予我們使用 @title 一個應該描述合約/介面的標題 contract, library, interface @author 作者的名字 contract, library, interf ......

    uj5u.com 2023-03-08 07:57:49 more
  • 評價指標:相似度、GAS消耗

    【代碼注釋自動生成方法綜述】 這些評測指標主要來自機器翻譯和文本總結等研究領域,可以評估候選文本(即基于代碼注釋自動方法而生成)和參考文本(即基于手工方式而生成)的相似度. BLEU指標^[^?88^^?^]^:其全稱是bilingual evaluation understudy.該指標是最早用于 ......

    uj5u.com 2023-02-23 07:27:39 more
  • 基于NOSTR協議的“公有制”版本的Twitter,去中心化社交軟體Damus

    最近,一個幽靈,Web3的幽靈,在網路游蕩,它叫Damus,這玩意詮釋了什么叫做病毒式營銷,滑稽的是,一個Web3產品卻在Web2的產品鏈上瘋狂傳銷,各方大佬紛紛為其背書,到底發生了什么?Damus的葫蘆里,賣的是什么藥? 注冊和簡單實用 很少有什么產品在用戶注冊環節會有什么噱頭,但Damus確實出 ......

    uj5u.com 2023-02-05 06:48:39 more