主頁 >  其他 > 從零實作深度學習框架——手寫前饋網路實作電影評論分類

從零實作深度學習框架——手寫前饋網路實作電影評論分類

2022-02-21 21:44:16 其他

引言

本著“凡我不能創造的,我就不能理解”的思想,本系列文章會基于純Python以及NumPy從零創建自己的深度學習框架,該框架類似PyTorch能實作自動求導,

要深入理解深度學習,從零開始創建的經驗非常重要,從自己可以理解的角度出發,盡量不使用外部完備的框架前提下,實作我們想要的模型,本系列文章的宗旨就是通過這樣的程序,讓大家切實掌握深度學習底層實作,而不是僅做一個調包俠,
本系列文章首發于微信公眾號:JavaNLP

我們已經了解了前饋神經網路的基礎知識,本文就基于前饋網路來解決實際問題——IMDB電影評論分類,

imdb資料集

imdb資料集是英文電影評論資料集,包含50000條兩極分化的評論,資料集被分為25000條用于訓練和25000條用于測驗的評論,它們都包含50%的正面和50%的負面評論,

我們想要訓練一個前饋網路能學到輸入一段英文評論,判斷這段評論是正面(表揚、鼓勵)還是負面(狂噴)的,屬于一個二分類問題,

我們先來看下資料集,為了簡單,我們先用keras提供的封裝方法加載資料集,需要引入from keras.datasets import imdb

def load_dataset():
    # 保留訓練資料中前10000個最常出現的單詞,舍棄低頻單詞
    (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=10000)

    return Tensor(X_train), Tensor(X_test), Tensor(y_train), Tensor(y_test)


def indices_to_sentence(indices: Tensor):
    # 單詞索引字典 word -> index
    word_index = imdb.get_word_index()
    # 逆單詞索引字典 index -> word
    reverse_word_index = dict(
        [(value, key) for (key, value) in word_index.items()])
    # 將index串列轉換為word串列
    #
    # 0、1、2 是為“padding”(填充)、“start of sequence”(序
    # 列開始)、“unknown”(未知詞)分別保留的索引
    decoded_review = ' '.join(
        [reverse_word_index.get(i - 3, '?') for i in indices.data])
    return decoded_review

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_dataset()

    print(indices_to_sentence(X_train[0]))
    print(y_train[0])

這里加載了第一個樣本,將索引還原成句子,最后列印出該句子對應的標簽,

? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little boy's that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
Tensor(1.0, requires_grad=False) # 對應的類別

很長的一段評論,該評論被標記為正面(1),

句子的處理

這是我們第一次接觸NLP相關任務,雖然我們人類能很容易地看懂文字,但是讓機器讀懂文字不是一件容易的事,

這里用了最簡單的方法,首先將句子拆分成一個個單詞,然后構造一個詞典來保存每個單詞和其對應的序號(imdb.get_word_index()),這里keras已經幫我們處理好了,

然后保存每個句子的時候,我們只需要保留句子中所有單詞對應的序號串列即可,得到序號串列相當于將句子進行了數字化,只有數字化之后,計算機才能處理,

每個樣本都是單詞序列串列,因此我們需要將它們還原成句子,人類才能看得懂,

但是我們不能將整數序列直接輸入神經網路,我們需要將串列轉換為向量,我們這里對序列串列使用ont-hot編碼,比如序列[3,5]會被轉換為10000維的向量,只有索引3和5的元素是1,相當于標記了哪些單詞出現在序列中,這是一個簡單的句子向量化方法,

這里我們把每個句子轉換成一個10000維的向量,這里的10000是我們設的最常見的單詞數,包括填充詞、序列開始詞和未知詞,每個句子都會有一個序列開始詞,表示這是一個句子的開始單詞;未知詞是來處理不常見單詞的,比如你不在這10000個常見單詞里面的詞;填充詞用于填充句子;

def vectorize_sequences(sequences, dimension=10000):
    # 默認生成一個[句子長度,維度數]的向量
    results = np.zeros((len(sequences), dimension), dtype='uint8')
    for i, sequence in enumerate(sequences):
        # 將第i個序列中,對應單詞序號處的位置置為1
        results[i, sequence] = 1
    return results

X_train = vectorize_sequences(X_train)
print(X_train[0])
[0 1 1 ... 0 0 0]

處理好句子之后,我們就可以將資料輸入到神經網路中,

構建前饋神經網路

輸入資料是向量,而標簽是標量(1或0),這和我們之前使用邏輯回歸構建的模型一樣,不過這次我們采用神經網路的方式,

我們使用前面介紹的單隱藏層前饋網路來處理這個問題,看一下效果如何,

首先設計我們的單隱藏層網路:

class Feedforward(nn.Module):
    '''
    簡單單隱藏層前饋網路,用于分類問題
    '''

    def __init__(self, input_size, hidden_size, output_size):
        '''

        :param input_size: 輸入維度
        :param hidden_size: 隱藏層大小
        :param output_size: 分類個數
        '''
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),  # 隱藏層,將輸入轉換為隱藏向量
            nn.ReLU(),  # 激活函式
            nn.Linear(hidden_size, output_size)  # 輸出層,將隱藏向量轉換為輸出
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

實作這種順序網路很簡單,就像堆疊石頭一樣,一層一層往上堆疊即可,

stack_of_stones

由于我們將使用之前介紹的BCELoss,因此最終的輸出只是logits即可,不需要是經過Sigmoid的概率,

訓練模型

由于我們的資料量足夠大,我們可以從訓練集中保留一部分資料作為驗證集,以監控訓練的效果,

# 保留驗證集
# X_train有25000條資料,我們保留10000條作為驗證集
X_val = X_train[:10000]
X_train = X_train[10000:]

y_val = y_train[:10000]
y_train = y_train[10000:]

下面我們構造模型,并準備優化器和損失器,由于我們加了批處理,這里計算總損失,而不是均值,

model = Feedforward(10000, 128, 1)  # 輸入大小10000,隱藏層大小128,輸出只有一個,代表判斷為正例的概率

optimizer = SGD(model.parameters(), lr=0.001)
# 先計算sum
loss = BCELoss(reduction="sum")

同時由于資料量較大,我們需要進行批處理,將訓練集和驗證集分成每批大小為512的批資料,訓練20輪,

epochs = 20
batch_size = 512 # 批大小
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# 由于資料過多,需要拆分成批次
X_train_batches, y_train_batches = make_batches(X_train, y_train,batch_size=batch_size)

X_val_batches, y_val_batches = make_batches(X_val, y_val, batch_size=batch_size)

for epoch in range(epochs):
    train_loss, train_accuracy = compute_loss_and_accury(X_train_batches, y_train_batches, model, loss, len(X_train), optimizer)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    with no_grad():
        val_loss, val_accuracy = compute_loss_and_accury(X_val_batches, y_val_batches, model, loss, len(X_val))

        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f"Epoch:{epoch}, Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}% | "
              f" Validation Loss:{val_loss:.4f} , Accuracy:{val_accuracy:.2f}%")

訓練程序中的列印如下:

Epoch:1, Training Loss: 0.6335, Accuracy: 60.20% |  Validation Loss:0.6429 , Accuracy:49.55%
Epoch:2, Training Loss: 0.6333, Accuracy: 66.75% |  Validation Loss:0.5527 , Accuracy:76.28%
Epoch:3, Training Loss: 0.5587, Accuracy: 75.22% |  Validation Loss:0.4782 , Accuracy:82.10%
Epoch:4, Training Loss: 0.4891, Accuracy: 77.11% |  Validation Loss:0.3758 , Accuracy:84.26%
Epoch:5, Training Loss: 0.5085, Accuracy: 75.09% |  Validation Loss:0.3763 , Accuracy:83.28%
Epoch:6, Training Loss: 0.3887, Accuracy: 82.52% |  Validation Loss:0.3544 , Accuracy:84.76%
Epoch:7, Training Loss: 0.3628, Accuracy: 83.79% |  Validation Loss:0.3584 , Accuracy:85.17%
Epoch:8, Training Loss: 0.3451, Accuracy: 84.71% |  Validation Loss:0.3532 , Accuracy:83.87%
Epoch:9, Training Loss: 0.3201, Accuracy: 85.83% |  Validation Loss:0.3433 , Accuracy:84.42%
Epoch:10, Training Loss: 0.3311, Accuracy: 84.93% |  Validation Loss:0.3058 , Accuracy:87.20%
Epoch:11, Training Loss: 0.2989, Accuracy: 87.04% |  Validation Loss:0.3484 , Accuracy:83.60%
Epoch:12, Training Loss: 0.2685, Accuracy: 88.61% |  Validation Loss:0.2958 , Accuracy:87.65%
Epoch:13, Training Loss: 0.2640, Accuracy: 88.35% |  Validation Loss:0.2957 , Accuracy:87.72%
Epoch:14, Training Loss: 0.2887, Accuracy: 87.17% |  Validation Loss:0.3808 , Accuracy:82.40%
Epoch:15, Training Loss: 0.3235, Accuracy: 85.68% |  Validation Loss:0.2926 , Accuracy:87.75%
Epoch:16, Training Loss: 0.2650, Accuracy: 88.68% |  Validation Loss:0.3038 , Accuracy:86.86%
Epoch:17, Training Loss: 0.2448, Accuracy: 89.58% |  Validation Loss:0.2906 , Accuracy:87.94%
Epoch:18, Training Loss: 0.2273, Accuracy: 90.34% |  Validation Loss:0.2915 , Accuracy:88.05%
Epoch:19, Training Loss: 0.1913, Accuracy: 91.97% |  Validation Loss:0.2889 , Accuracy:88.33%
Epoch:20, Training Loss: 0.2069, Accuracy: 91.22% |  Validation Loss:0.2894 , Accuracy:88.07%

光看列印不夠直觀,我們可以繪制訓練損失和驗證損失:

訓練和驗證損失

還可以繪制訓練和驗證準確率的變化曲線:

訓練和驗證準確率

看起來模型還不錯,但是真正怎么樣還需要測驗之后才知道,我們現在來預測沒有看過的25000條記錄:

# 最后在測驗集上測驗
with no_grad():
   X_test, y_test = Tensor(X_test), Tensor(y_test)
   outputs = model(X_test)
   correct = np.sum(sigmoid(outputs).numpy().round() == y_test.numpy())
   accuracy = 100 * correct / len(y_test)
   print(f"Test Accuracy:{accuracy}")
Test Accuracy:88.004

嗯,我們直接純手寫實作的前饋網路模型和Keras的前饋模型表現差不多1,還可以!

完整代碼

完整代碼筆者上傳到了程式員最大交友網站上去了,地址: 👉 https://github.com/nlp-greyfoss/metagrad

References


  1. Python深度學習 ??

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/429770.html

標籤:AI

上一篇:【原理+代碼】Python實作Topsis分析法(優劣解距離法)

下一篇:PointPillars論文決議和OpenPCDet代碼決議

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more