主頁 >  其他 > 【圖神經網路】ChebyNet-切比雪夫多項式近似圖卷積核

【圖神經網路】ChebyNet-切比雪夫多項式近似圖卷積核

2020-10-03 10:48:01 其他

本文為圖神經網路學習筆記,講解 ChebyNet-切比雪夫多項式近似圖卷積核,歡迎在評論區與我交流👏

ChebyNet 簡介

見【圖卷積網路】,

ChebyNet 實作

對圖的鄰接矩陣進行歸一化處理得到拉普拉斯矩陣,歸一化方法有:
{ L = D ? A L s y m = D ? 1 / 2 L D ? 1 / 2 L r w = D ? 1 L \left\{ \begin{array}{rcl} L=D-A \\ L^{sym}=D^{-1/2}LD^{-1/2}\\ L^{rw}=D^{-1}L \end{array} \right. ????L=D?ALsym=D?1/2LD?1/2Lrw=D?1L?
根據得到的歸一化拉普拉斯矩陣計算:
L ^ = 2 λ m a x L ? I N \hat{L}=\frac{2}{\lambda_{max}}L-I_N L^=λmax?2?L?IN?
Re-scaled 特征值對角矩陣,將其變換到 [ ? 1 , 1 ] [-1,1] [?1,1] 之間:

num_nodes = x.shape[0]
norm_edge_index, norm_edge_weight = chebnet_norm_edge(edge_index, num_nodes, edge_weight, lambda_max, normalization_type=normalization_type)                                            

利用切比雪夫多項式的迭代定義遞推計算高階項(節省大量運算),最后輸出模型結果,即多項式和 y = σ ( ∑ k = 0 K θ k T k ( L ^ ) ( x ) ) y=\sigma(\sum\limits_{k=0}^K\theta_kT_k(\hat{L})(x)) y=σ(k=0K?θk?Tk?(L^)(x)) 計算損失或評估模型效果:

T0_x = x
T1_x = x
out = tf.matmul(T0_x, kernel[0]) # 兩個矩陣相乘 

if K > 1:
    T1_x = aggregate_neighbors(x, norm_edge_index, norm_edge_weight, gcn_mapper, sum_reducer, identity_updater)
    out += tf.matmul(T1_x, kernel[1])

# T_{n+1}=2T_n-T_{n-1}
for i in range(2, K):
    T2_x = aggregate_neighbors(T1_x, norm_edge_index, norm_edge_weight, gcn_mapper, sum_reducer, identity_updater)  # L^T_{k-1}(L^)
    T2_x = 2.0 * T2_x - T0_x
    out += tf.matmul(T2_x, kernel[i])

    T0_x, T1_x = T1_x, T2_x

if bias is not None:
    out += bias

if activation is not None:
    out += activation(out)

return out

模型構建

本教程使用的核心庫是 tf_geometric,我們用它來進行圖資料匯入、圖資料預處理及圖神經網路構建,ChebNet 的具體實作已經在上面詳細介紹,LaplacianMaxEigenvalue 獲取拉普拉斯矩陣的最大特征值,后面使用 keras.metrics.Accuracy 評估模型性能:

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tf_geometric.layers.conv.chebnet import chebNet
from tf_geometric.datasets.cora import CoraDataset
from tf_geometric.utils.graph_utils import LaplacianMaxEigenvalue
from tqdm import tqdm

使用 tf_geometric 自帶的圖結構資料介面加載 Cora 資料集:

# 加載 Cora 資料集
graph, (train_index, valid_index, test_index) = CoraDataset().load_data()

獲取圖拉普拉斯矩陣的最大特征值:

# 獲取 lambda_max
graph_lambda_max = LaplacianMaxEigenvalue(graph.x, graph.edge_index, graph.edge_weight)

定義模型,引入 keras.layers 中的 Dropout 層隨機關閉神經元緩解過擬合,由于 Dropout 層在訓練和預測階段的狀態不同,通過引數 training 來決定是否需要 Dropout 發揮作用:

model = chebNet(64, K=3, lambda_max=graph_lambda_max()
fc = tf.keras.Sequential([
    keras.layers.Dropout(0.5), # Dropout 層隨機關閉神經元緩解過擬合
    keras.layers.Dense(num_classes)])

def forward(graph, training=False):
    h = model([graph.x, graph.edge_index, graph.edge_weight])
    h = fc(h, training=training) # 通過引數 training 來決定是否需要 Dropout 發揮作用
    return h

ChebyNet 訓練

模型的訓練與其他基于 Tensorflow 框架的模型訓練基本一致,主要步驟有定義優化器,計算誤差與梯度,反向傳播等,然后分別計算驗證集和測驗集上的準確率:

# 定義優化器
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)

best_test_acc = tmp_valid_acc = 0
for step in tqdm(range(1, 101)):
    with tf.GradientTape() as tape:
      	# 前向傳播
        logits = forward(graph, training=True)
        # 計算損失
        loss = compute_loss(logits, train_index, tape.watched_variables())

    vars = tape.watched_variables()
    grads = tape.gradient(loss, vars) # 計算梯度
    optimizer.apply_gradients(zip(grads, vars)) # 梯度下降優化

    valid_acc = evaluate(valid_index) # 計算驗證集
    test_acc = evaluate(test_index) # 計算測驗集
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        tmp_valid_acc = valid_acc
    print("step = {}\tloss = {}\tvalid_acc = {}\tbest_test_acc = {}".format(step, loss, tmp_valid_acc, best_test_acc))

用交叉熵損失函式計算模型損失,注意在加載 Cora 資料集時,回傳值是整個圖資料以及相應的 train_indexvalid_indextest_index,TAGCN 在訓練時輸入整個Graph,計算損失時通過 train_index 計算模型在訓練集上的迭代損失,因此,此時傳入的 mask_indextrain_index,由于是多分類任務,需要將節點的標簽轉換為 one-hot 向量以便與模型輸出的結果維度對應,由于圖神經模型在小資料集上很容易過擬合,所以這里用 L 2 L_2 L2? 正則化緩解過擬合:

def compute_loss(logits, mask_index, vars):
    masked_logits = tf.gather(logits, mask_index) # 前向傳播(預測)的結果,取訓練資料部分
    masked_labels = tf.gather(graph.y, mask_index) # 真實結果,取訓練資料部分
    losses = tf.nn.softmax_cross_entropy_with_logits(
        logits=masked_logits, # 預測結果
        labels=tf.one_hot(masked_labels, depth=num_classes) # 真實結果,即標簽
    )
		# 用 L_2 正則化緩解過擬合
    kernel_vals = [var for var in vars if "kernel" in var.name]
    l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vals]

    # reduce_mean 計算張量的平均值;tf.add_n 串列對應元素相加
    return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4

ChebyNet 評估

評估模型性能時只需傳入 valid_masktest_mask,通過 tf.gather 函式可以拿出驗證集或測驗集在模型上的預測結果與真實標簽,用 keras自帶的 keras.metrics.Accuracy 計算準確率:

def evaluate(mask):
    logits = forward(graph) # 前向傳播結果
    logits = tf.nn.log_softmax(logits, axis=-1) # 假設函式處理
    masked_logits = tf.gather(logits, mask) # 預測結果
    masked_labels = tf.gather(graph.y, mask) # 真實標簽

    # 回傳預測結果向量最大值的索引
    y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32)

    accuracy_m = keras.metrics.Accuracy()
    accuracy_m.update_state(masked_labels, y_pred)
    return accuracy_m.result().numpy() # 準確度結果轉換為 numpy 回傳

運行結果

 0%|          | 0/100 [00:00<?, ?it/s]step = 1	loss = 1.9817407131195068	valid_acc = 0.7139999866485596	best_test_acc = 0.7089999914169312
  2%|▏         | 2/100 [00:01<00:55,  1.76it/s]step = 2	loss = 1.6069653034210205	valid_acc = 0.75	best_test_acc = 0.7409999966621399
step = 3	loss = 1.2625869512557983	valid_acc = 0.7720000147819519	best_test_acc = 0.7699999809265137
  4%|▍         | 4/100 [00:01<00:48,  1.98it/s]step = 4	loss = 0.9443040490150452	valid_acc = 0.7760000228881836	best_test_acc = 0.7749999761581421
  5%|▌         | 5/100 [00:02<00:46,  2.06it/s]step = 5	loss = 0.7023431062698364	valid_acc = 0.7760000228881836	best_test_acc = 0.7770000100135803
  ...
96	loss = 0.0799005851149559	valid_acc = 0.7940000295639038	best_test_acc = 0.8080000281333923
 96%|█████████▌| 96/100 [00:43<00:01,  2.31it/s]step = 97	loss = 0.0768655389547348	valid_acc = 0.7940000295639038	best_test_acc = 0.8080000281333923
 97%|█████████▋| 97/100 [00:43<00:01,  2.33it/s]step = 98	loss = 0.0834992527961731	valid_acc = 0.7940000295639038	best_test_acc = 0.8080000281333923
 99%|█████████▉| 99/100 [00:44<00:00,  2.34it/s]step = 99	loss = 0.07315651327371597	valid_acc = 0.7940000295639038	best_test_acc = 0.8080000281333923
100%|██████████| 100/100 [00:44<00:00,  2.23it/s]
step = 100	loss = 0.07698118686676025	valid_acc = 0.7940000295639038	best_test_acc = 0.8080000281333923

完整代碼見【demo_chebynet.py】,

有幫助的話點個贊加關注吧 😃

參考

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/151157.html

標籤:AI

上一篇:2020-10-02:golang如何寫一個插件?

下一篇:安裝部署WEB安全測驗用靶機(AWVA MCIR Pikachu mutillidae bWAPP)

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more