主頁 >  其他 > 用PyTorch重新創建Keras API

用PyTorch重新創建Keras API

2020-10-28 07:05:44 其他

作者|Bipin Krishnan P
編譯|VK
來源|Towards Data Science

介紹

Francois Chollet寫的《Deep Learning with Python》一書讓我進入了深度學習的世界,從那時起我就愛上了Keras的風格,

Keras是我的第一個框架,然后是Tensorflow,接著進入PyTorch,

老實說,在Keras的模型訓練中,我很興奮這個進度條,真是太棒了,

那么,為什么不嘗試把Keras訓練模型的經驗帶到PyTorch呢?

這個問題讓我開始了作業,最后我用所有那些花哨的進度條重現了Keras的Dense層、卷積層和平坦層,

模型可以通過堆疊一層到另一層來創建,并通過簡單地呼叫fit方法進行訓練,該方法類似于Keras的作業方式,

Keras的作業方式如下:

#一層一層疊起來
#采用輸入資料的形狀
inputs = keras.Input(shape=(784,))
l1 = layers.Dense(64, activation="relu")(inputs)
l2 = layers.Dense(64, activation="relu")(l1)
outputs = layers.Dense(10)(l2)

model = keras.Model(inputs=inputs, outputs=outputs)

#輸出模型摘要
model.summary()

#模型訓練和評估
model.fit(x_train, y_train, epochs=2)
model.evaluate(x_test, y_test)

1.匯入所需的庫

你可能不熟悉庫pkbar,它用于顯示類似Keras的進度條,

!pip install pkbar
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torchsummary import summary as summary_
import pkbar

import warnings
warnings.filterwarnings('ignore')

2.輸入層和dense層

輸入層只是以資料的單一實體的形式被傳遞到神經網路并回傳它,對于全連接的網路,它將類似于(1,784),對于卷積神經網路,它將是影像的尺寸(高度×寬度×通道),

使用大寫字母來命名python函式是違反規則的,但是我們暫時忽略它(Keras源代碼的某些部分使用相同的約定),

def Input(shape):
  Input.shape = shape
  return Input.shape

def get_conv_output(shape, inputs):
  bs = 1
  data = https://www.cnblogs.com/panchuangai/p/Variable(torch.rand(bs, *shape))
  output_feat = inputs(data)

  return output_feat.size(1)

def same_pad(h_in, kernal, stride, dilation):
  return (stride*(h_in-1)-h_in+(dilation*(kernal-1))+1) / 2.0

Dense類通過傳遞該層的輸出神經元數量和激活函式來初始化,呼叫Dense層時,前一層作為輸入傳遞,

現在我們有了關于前一層的資訊,如果前一層是輸入層,則創建一個PyTorch線性層,其中輸入層回傳的形狀和輸出神經元的數量作為Dense類初始化期間的引數,

如果前一層是Dense層,我們通過在Dense類中增加一個PyTorch線性層和一個激活層來擴展神經網路,

如果前一層是卷積層或平坦層,我們將創建一個名為get_conv_output()的實用函式,通過卷積層和平坦層得到影像的輸出形狀,此維度是必需的,因為如果不向in_features引數傳遞值,則無法在PyTorch中創建線性層,

函式的作用是將影像形狀和卷積神經網路模型作為輸入,然后,它創建一個與影像形狀相同的虛擬張量,并將其傳遞給卷積網路(具有平坦層),并回傳從中輸出的資料的大小,該大小作為值傳遞給PyTorch線性層中的in_features引數,

class Dense(nn.Module):
  def __init__(self, outputs, activation):
    super().__init__()
    self.outputs = outputs
    self.activation = activation

  def __call__(self, inputs):
    self.inputs_size = 1
    
    if type(inputs) == tuple:
      for i in range(len(inputs)):
        self.inputs_size *= inputs[i]
      
      self.layers = nn.Sequential(
        nn.Linear(self.inputs_size, self.outputs),
        self.activation
    )

      return self.layers

    elif isinstance(inputs[-2], nn.Linear):
      self.inputs = inputs
      self.layers = list(self.inputs)
      self.layers.extend([nn.Linear(self.layers[-2].out_features, self.outputs), self.activation])

      self.layers = nn.Sequential(*self.layers)

      return self.layers

    else:
      self.inputs = inputs
      self.layers = list(self.inputs)
      self.layers.extend([nn.Linear(get_conv_output(Input.shape, self.inputs), self.outputs), self.activation])

      self.layers = nn.Sequential(*self.layers)

      return self.layers

3.平坦層

為了創建一個平坦層,我們將創建一個名為FlattenedLayer的自定義層類,它接受張量作為輸入,并在前向傳播期間回傳張量的平坦版本,

我們將創建另一個名為flatten的類,當呼叫這個層時,前面的層作為輸入傳遞,然后flatten類通過在前面的層上添加我們自定義創建的FlattenedLayer類來擴展網路,

因此,所有到達平坦層的資料都是使用我們自定義創建的平坦層進行平坦的,

class FlattenedLayer(nn.Module):
  def __init__(self):
    super().__init__()
    pass

  def forward(self, input):
      self.inputs = input.view(input.size(0), -1)
      return self.inputs


class Flatten():
  def __init__(self):
    pass

  def __call__(self, inputs):
    self.inputs = inputs
    self.layers = list(self.inputs)
    self.layers.extend([FlattenedLayer()])
    self.layers = nn.Sequential(*self.layers)

    return self.layers

4.卷積層

我們將通過傳入濾波器數量、內核大小、步長、填充、膨脹和激活函式來初始化Conv2d層,

現在,當呼叫Conv2d層時,前面的層被傳遞給它,如果前一層是Input layer,則是一個Pytorch conv2d層,其中提供了濾波器數量、內核大小、步長、填充,擴張和激活函式被創建,其中in_channels的值取自輸入形狀中的通道數,

如果前一層是卷積層,則通過添加一個PyTorch Conv2d層和激活函式來擴展前一層,激活函式的值取自前一層的out_channels ,

在填充的情況下,如果用戶需要保留從該層傳出的資料的維度,則可以將padding的值指定為“same”,而不是整數,

如果padding的值被指定為“same”,那么將使用一個名為same_pad()的實用函式來獲取padding的值,以保留給定輸入大小、內核大小、步長和膨脹的維度,

可以使用前面討論的get_conv_output()實用程式函式獲得輸入大小,

class Conv2d(nn.Module):
  def __init__(self, filters, kernel_size, strides, padding, dilation, activation):
    super().__init__()
    self.filters = filters
    self.kernel = kernel_size
    self.strides = strides
    self.padding = padding
    self.dilation = dilation
    self.activation = activation

  def __call__(self, inputs):

    if type(inputs) == tuple:
      self.inputs_size = inputs

      if self.padding == 'same':
        self.padding = int(same_pad(self.inputs_size[-2], self.kernel, self.strides, self.dilation))
      else:
        self.padding = self.padding

      self.layers = nn.Sequential(
        nn.Conv2d(self.inputs_size[-3],
                  self.filters, 
                  self.kernel, 
                  self.strides, 
                  self.padding,
                  self.dilation),
        self.activation
    )

      return self.layers

    else:
      if self.padding == 'same':
        self.padding = int(same_pad(get_conv_output(Input.shape, inputs), self.kernel, self.strides, self.dilation))
      else:
        self.padding = self.padding

      self.inputs = inputs
      self.layers = list(self.inputs)
      self.layers.extend(
             [nn.Conv2d(self.layers[-2].out_channels, 
                    self.filters, 
                    self.kernel, 
                    self.strides, 
                    self.padding,
                    self.dilation),
             self.activation]
          )
      self.layers = nn.Sequential(*self.layers)

      return self.layers

5.模型類

在構建了模型的體系結構之后,通過傳入輸入層和輸出層來初始化模型類,但是我已經給出了一個額外的引數,名為device,它在Keras中不存在,這個引數接受值為'CPU'或'CUDA',它將把整個模型移動到指定的設備,

model類的parameters方法用于回傳要給PyTorch優化器的模型引數,

model類有一個名為compile的方法,它接受訓練模型所需的優化器和丟失函式,模型類的摘要方法是借助torch的summary庫顯示所創建模型的摘要,

采用擬合方法對模型進行訓練,該方法以輸入特征集、目標資料集和epoch數為引數,它顯示由損失函式計算的損失和使用pkbar庫的訓練進度,

評估會計算驗證資料集的損失和精度,

當使用PyTorch資料加載程式加載資料時,將使用fit_generator、evaluate_generator 和predict_generator ,fit_generator 以訓練集資料加載器和epoch作為引數,evaluate_generator和predict_generator分別使用驗證集資料加載器和測驗資料加載器來衡量模型對未查看資料的執行情況,

class Model():
  def __init__(self, inputs, outputs, device):
    self.input_size = inputs
    self.device = device
    self.model = outputs.to(self.device)

  def parameters(self):
    return self.model.parameters()

  def compile(self, optimizer, loss):
    self.opt = optimizer
    self.criterion = loss

  def summary(self):
    summary_(self.model, self.input_size, device=self.device)
    print("Device Type:", self.device)

  def fit(self, data_x, data_y, epochs):
    self.model.train()

    for epoch in range(epochs):
      print("Epoch {}/{}".format(epoch+1, epochs))
      progress = pkbar.Kbar(target=len(data_x), width=25)
      
      for i, (data, target) in enumerate(zip(data_x, data_y)):
        self.opt.zero_grad()

        train_out = self.model(data.to(self.device))
        loss = self.criterion(train_out, target.to(self.device))
        loss.backward()

        self.opt.step()

        progress.update(i, values=[("loss: ", loss.item())])

      progress.add(1)

  def evaluate(self, test_x, test_y):
    self.model.eval()
    correct, loss = 0.0, 0.0

    progress = pkbar.Kbar(target=len(test_x), width=25)

    for i, (data, target) in enumerate(zip(test_x, test_y)):
      out = self.model(data.to(self.device))
      loss += self.criterion(out, target.to(self.device))

      correct += ((torch.max(out, 1)[1]) == target.to(self.device)).sum()

      progress.update(i, values=[("loss", loss.item()/len(test_x)), ("acc", (correct/len(test_x)).item())])
    progress.add(1)


  def fit_generator(self, generator, epochs):
    self.model.train()

    for epoch in range(epochs):
      print("Epoch {}/{}".format(epoch+1, epochs))
      progress = pkbar.Kbar(target=len(generator), width=25)

      for i, (data, target) in enumerate(generator):
        self.opt.zero_grad()

        train_out = self.model(data.to(self.device))
        loss = self.criterion(train_out.squeeze(), target.to(self.device))
        loss.backward()

        self.opt.step()

        progress.update(i, values=[("loss: ", loss.item())])

      progress.add(1)
      

  def evaluate_generator(self, generator):
    self.model.eval()
    correct, loss = 0.0, 0.0

    progress = pkbar.Kbar(target=len(generator), width=25)

    for i, (data, target) in enumerate(generator):
      out = self.model(data.to(self.device))
      loss += self.criterion(out.squeeze(), target.to(self.device))

      correct += (torch.max(out.squeeze(), 1)[1] == target.to(self.device)).sum()

      progress.update(i, values=[("test_acc", (correct/len(generator)).item()), ("test_loss", loss.item()/len(generator))])

    progress.add(1)

  def predict_generator(self, generator):
    self.model.train()
    out = []
    for i, (data, labels) in enumerate(generator):
      out.append(self.model(data.to(self.device)))

    return out

結尾

我用Dense層和卷積神經網路在CIFAR100、CIFAR10和MNIST資料集上測驗了代碼,它作業得很好,但還有很大的改進空間,

這是一個有趣的專案,我已經作業了3-4天,它真的突破了我用PyTorch編程的極限,

你可以在這里查看完整的代碼,并在上面提到的資料集上進行訓練,或者你可以自由地調整代碼以適合你在colab中的喜好:https://colab.research.google.com/github/bipinKrishnan/torchkeras/blob/master/functional_api_v1.ipynb

原文鏈接:https://towardsdatascience.com/recreating-keras-functional-api-with-pytorch-cc2974f7143c

歡迎關注磐創AI博客站:
http://panchuang.net/

sklearn機器學習中文官方檔案:
http://sklearn123.com/

歡迎關注磐創博客資源匯總站:
http://docs.panchuang.net/

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/194814.html

標籤:其他

上一篇:回圈神經網路

下一篇:Torch:從特征提取到模型的語音識別

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more