主頁 >  其他 > 用 Java 訓練出一只“不死鳥”

用 Java 訓練出一只“不死鳥”

2020-12-30 13:16:28 其他

作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手機游戲,因其簡單的玩法但極度困難的設定迅速走紅全網,隨著深度學習(DL)與增強學習(RL)等前沿演算法的發展,我們可以使用 Java 非常方便地訓練出一個智能體來控制 Flappy Bird,

故事開始于《GitHub 上的大佬們打完招呼,會聊些什么?》,今天我們就來一起看一下如何用 Java 訓練出一個不死鳥,游戲專案我們使用了一個僅用 Java 基本類別庫撰寫的 FlappyBird 游戲,在訓練方面,我們使用 DeepJavaLibrary 一個基于 Java 的深度學習框架來構建增強學習訓練網路并進行訓練,經過了300 萬步(四小時)的訓練后,小鳥已經可以獲得最高 8000 多分的成績,靈活穿梭于水管之間,

在本文中,我們將從原理開始一步一步實作增強學習演算法并用它對游戲進行訓練,如果任何一個時刻不清楚如何繼續進行下去,可以參閱專案的原始碼,

專案地址:https://github.com/kingyuluk/RL-FlappyBird

增強學習(RL)的架構

在這一節會介紹主要用到的演算法以及神經網路,幫助你更好的了解如何進行訓練,本專案與 DeepLearningFlappyBird 使用了類似的方法進行訓練,演算法整體的架構是 Q-Learning + 卷積神經網路(CNN),把游戲每一幀的狀態存盤起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經網路的訓練資料,

CNN 訓練簡述

CNN 的輸入資料為連續的 4 幀影像,我們將這影像 stack 起來作為小鳥當前的“observation”,影像會轉換成灰度圖以減少所需的訓練資源,影像存盤的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 陣列里的元素就是當前幀的像素值,這些資料將輸入到 CNN 后將輸出 (batch size, 2) 的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應的收益,

訓練資料

在小鳥采取動作后,我們會得到 preObservation and currentObservation 即是兩組 4 幀的連續的影像表示小鳥動作前和動作后的狀態,然后我們將 preObservation, currentObservation, action, reward, terminal 組成的五元組作為一個 step 存進 replayBuffer 中,它是一個有限大小的訓練資料集,他會隨著最新的操作動態更新內容,

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

訓練的三個周期

訓練分為 3 個不同的周期以更好地生成訓練資料:

  • Observe(觀察) 周期:隨機產生訓練資料

  • Explore (探索) 周期:隨機與推理動作結合更新訓練資料

  • Training (訓練) 周期:推理動作主導產生新資料

通過這種訓練模式,我們可以更好的達到預期效果,

處于 Explore 周期時,我們會根據權重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作,訓練前期,隨機動作的權重會非常大,因為模型的決策十分不準確 (甚至不如隨機),在訓練后期時,隨著模型學習的動作逐步增加,我們會不斷增加模型推理動作的權重并最終使它成為主導動作,調節隨機動作的引數叫做 epsilon 它會隨著訓練的程序不斷變化,

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

訓練邏輯

首先,我們會從 replayBuffer 中隨機抽取一批資料作為作為訓練集,然后將 preObservation 輸入到神經網路得到所有行為的 reward(Q)作為預測值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

postObservation 同樣會輸入到神經網路,根據馬爾科夫決策程序以及貝爾曼價值函式計算出所有行為的 reward(targetQ)作為真實值:

// 將 postInput 輸入到神經網路中得到 targetQReward 是 (batchsize,2) 的矩陣,根據 Q-learning 的演算法,每一次的 targetQ 需要根據當前環境是否結束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList,
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在訓練結束時,計算 Q 和 targetQ 的損失值,并在 CNN 中更新權重,

卷積神經網路模型(CNN)

我們采用了采用了 3 個卷積層,4 個 relu 激活函式以及 2 個全連接層的神經網路架構,

layerinput shapeoutput shape
conv2d(batchSize, 4, 80, 80)(batchSize,4,20,20)
conv2d(batchSize, 4, 20 ,20)(batchSize, 32, 9, 9)
conv2d(batchSize, 32, 9, 9)(batchSize, 64, 7, 7)
linear(batchSize, 3136)(batchSize, 512)
linear(batchSize, 512)(batchSize, 2)

訓練程序

DJL 的 RL 庫中提供了非常方便的用于實作強化學習的介面:(RlEnv, RlAgent, ReplayBuffer),

  • 實作 RlAgent 介面即可構建一個可以進行訓練的智能體,

  • 在現有的游戲環境中實作 RlEnv 介面即可生成訓練所需的資料,

  • 創建 ReplayBuffer 可以存盤并動態更新訓練資料,

在實作這些介面后,只需要呼叫 step 方法:

RlEnv.step(action, training);

這個方法會將 RlAgent 決策出的動作輸入到游戲環境中獲得反饋,我們可以在 RlEnv 中提供的 runEnviroment 方法中呼叫 step 方法,然后只需要重復執行 runEnvironment 方法,即可不斷地生成用于訓練的資料,

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

我們將 ReplayBuffer 可存盤的 step 數量設定為 50000,在 observe 周期我們會先向 replayBuffer 中存盤 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學習,

在 explore 和 training 周期,神經網路會隨機從 replayBuffer 中生成訓練集并將它們輸入到模型中訓練,我們使用 Adam 優化器和 MSE 損失函式迭代神經網路,

神經網路輸入預處理

首先將影像大小 resize 成 80x80 并轉為灰度圖,這有助于在不丟失資訊的情況下提高訓練速度,

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

然后我們把連續的四幀影像作為一個輸入,為了獲得連續四幀的連續影像,我們維護了一個全域的影像佇列保存游戲執行緒中的影像,每一次動作后替換掉最舊的一幀,然后把佇列里的影像 stack 成一個單獨的 NDArray,

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

一旦以上部分完成,我們就可以開始訓練了,訓練優化為了獲得最佳的訓練性能,我們關閉了 GUI 以加快樣本生成速度,并使用 Java 多執行緒將訓練回圈和樣本生成回圈分別在不同的執行緒中運行,

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

總結

這個模型在 NVIDIA T4 GPU 訓練了大概 4 個小時,更新了 300 萬步,訓練后的小鳥已經可以完全自主控制動作靈活穿梭于管道之間,訓練后的模型也同樣上傳到了倉庫中供您測驗,在此專案中 DJL 提供了強大的訓練 API 以及模型庫支持,使得在 Java 開發程序中得心應手,

本專案完整代碼:https://github.com/kingyuluk/RL-FlappyBird



關注后點擊“入伙”,加入我們

「閱讀原文,點個 star 吧」

削微寒 CSDN認證博客專家 萌新
微信搜【HelloGitHub】關注后可以找到我,分享 GitHub 上有趣、入門級的開源專案,每月 28 號更新,已持續維護 4 年有余,

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/242456.html

標籤:其他

上一篇:GitHub 熱榜:這款超硬核的 OCR 開源工具,我給 99.99 分!

下一篇:后疫情時代2020年后游戲行業可能會有哪些發展趨勢?

標籤雲
其他(157675) Python(38076) JavaScript(25376) Java(17977) C(15215) 區塊鏈(8255) C#(7972) AI(7469) 爪哇(7425) MySQL(7132) html(6777) 基礎類(6313) sql(6102) 熊猫(6058) PHP(5869) 数组(5741) R(5409) Linux(5327) 反应(5209) 腳本語言(PerlPython)(5129) 非技術區(4971) Android(4554) 数据框(4311) css(4259) 节点.js(4032) C語言(3288) json(3245) 列表(3129) 扑(3119) C++語言(3117) 安卓(2998) 打字稿(2995) VBA(2789) Java相關(2746) 疑難問題(2699) 细绳(2522) 單片機工控(2479) iOS(2429) ASP.NET(2402) MongoDB(2323) 麻木的(2285) 正则表达式(2254) 字典(2211) 循环(2198) 迅速(2185) 擅长(2169) 镖(2155) 功能(1967) .NET技术(1958) Web開發(1951) python-3.x(1918) HtmlCss(1915) 弹簧靴(1913) C++(1909) xml(1889) PostgreSQL(1872) .NETCore(1853) 谷歌表格(1846) Unity3D(1843) for循环(1842)

熱門瀏覽
  • 網閘典型架構簡述

    網閘架構一般分為兩種:三主機的三系統架構網閘和雙主機的2+1架構網閘。 三主機架構分別為內端機、外端機和仲裁機。三機無論從軟體和硬體上均各自獨立。首先從硬體上來看,三機都用各自獨立的主板、記憶體及存盤設備。從軟體上來看,三機有各自獨立的作業系統。這樣能達到完全的三機獨立。對于“2+1”系統,“2”分為 ......

    uj5u.com 2020-09-10 02:00:44 more
  • 如何從xshell上傳檔案到centos linux虛擬機里

    如何從xshell上傳檔案到centos linux虛擬機里及:虛擬機CentOs下執行 yum -y install lrzsz命令,出現錯誤:鏡像無法找到軟體包 前言 一、安裝lrzsz步驟 二、上傳檔案 三、遇到的問題及解決方案 總結 前言 提示:其實很簡單,往虛擬機上安裝一個上傳檔案的工具 ......

    uj5u.com 2020-09-10 02:00:47 more
  • 一、SQLMAP入門

    一、SQLMAP入門 1、判斷是否存在注入 sqlmap.py -u 網址/id=1 id=1不可缺少。當注入點后面的引數大于兩個時。需要加雙引號, sqlmap.py -u "網址/id=1&uid=1" 2、判斷文本中的請求是否存在注入 從文本中加載http請求,SQLMAP可以從一個文本檔案中 ......

    uj5u.com 2020-09-10 02:00:50 more
  • Metasploit 簡單使用教程

    metasploit 簡單使用教程 浩先生, 2020-08-28 16:18:25 分類專欄: kail 網路安全 linux 文章標簽: linux資訊安全 編輯 著作權 metasploit 使用教程 前言 一、Metasploit是什么? 二、準備作業 三、具體步驟 前言 Msfconsole ......

    uj5u.com 2020-09-10 02:00:53 more
  • 游戲逆向之驅動層與用戶層通訊

    驅動層代碼: #pragma once #include <ntifs.h> #define add_code CTL_CODE(FILE_DEVICE_UNKNOWN,0x800,METHOD_BUFFERED,FILE_ANY_ACCESS) /* 更多游戲逆向視頻www.yxfzedu.com ......

    uj5u.com 2020-09-10 02:00:56 more
  • 北斗電力時鐘(北斗授時服務器)讓網路資料更精準

    北斗電力時鐘(北斗授時服務器)讓網路資料更精準 北斗電力時鐘(北斗授時服務器)讓網路資料更精準 京準電子科技官微——ahjzsz 近幾年,資訊技術的得了快速發展,互聯網在逐漸普及,其在人們生活和生產中都得到了廣泛應用,并且取得了不錯的應用效果。計算機網路資訊在電力系統中的應用,一方面使電力系統的運行 ......

    uj5u.com 2020-09-10 02:01:03 more
  • 【CTF】CTFHub 技能樹 彩蛋 writeup

    ?碎碎念 CTFHub:https://www.ctfhub.com/ 筆者入門CTF時時剛開始刷的是bugku的舊平臺,后來才有了CTFHub。 感覺不論是網頁UI設計,還是題目質量,賽事跟蹤,工具軟體都做得很不錯。 而且因為獨到的金幣制度的確讓人有一種想去刷題賺金幣的感覺。 個人還是非常喜歡這個 ......

    uj5u.com 2020-09-10 02:04:05 more
  • 02windows基礎操作

    我學到了一下幾點 Windows系統目錄結構與滲透的作用 常見Windows的服務詳解 Windows埠詳解 常用的Windows注冊表詳解 hacker DOS命令詳解(net user / type /md /rd/ dir /cd /net use copy、批處理 等) 利用dos命令制作 ......

    uj5u.com 2020-09-10 02:04:18 more
  • 03.Linux基礎操作

    我學到了以下幾點 01Linux系統介紹02系統安裝,密碼啊破解03Linux常用命令04LAMP 01LINUX windows: win03 8 12 16 19 配置不繁瑣 Linux:redhat,centos(紅帽社區版),Ubuntu server,suse unix:金融機構,證券,銀 ......

    uj5u.com 2020-09-10 02:04:30 more
  • 05HTML

    01HTML介紹 02頭部標簽講解03基礎標簽講解04表單標簽講解 HTML前段語言 js1.了解代碼2.根據代碼 懂得挖掘漏洞 (POST注入/XSS漏洞上傳)3.黑帽seo 白帽seo 客戶網站被黑帽植入劫持代碼如何處理4.熟悉html表單 <html><head><title>TDK標題,描述 ......

    uj5u.com 2020-09-10 02:04:36 more
最新发布
  • 2023年最新微信小程式抓包教程

    01 開門見山 隔一個月發一篇文章,不過分。 首先回顧一下《微信系結手機號資料庫被脫庫事件》,我也是第一時間得知了這個訊息,然后跟蹤了整件事情的經過。下面是這起事件的相關截圖以及近日流出的一萬條資料樣本: 個人認為這件事也沒什么,還不如關注一下之前45億快遞資料查詢渠道疑似在近日復活的訊息。 訊息是 ......

    uj5u.com 2023-04-20 08:48:24 more
  • web3 產品介紹:metamask 錢包 使用最多的瀏覽器插件錢包

    Metamask錢包是一種基于區塊鏈技術的數字貨幣錢包,它允許用戶在安全、便捷的環境下管理自己的加密資產。Metamask錢包是以太坊生態系統中最流行的錢包之一,它具有易于使用、安全性高和功能強大等優點。 本文將詳細介紹Metamask錢包的功能和使用方法。 一、 Metamask錢包的功能 數字資 ......

    uj5u.com 2023-04-20 08:47:46 more
  • vulnhub_Earth

    前言 靶機地址->>>vulnhub_Earth 攻擊機ip:192.168.20.121 靶機ip:192.168.20.122 參考文章 https://www.cnblogs.com/Jing-X/archive/2022/04/03/16097695.html https://www.cnb ......

    uj5u.com 2023-04-20 07:46:20 more
  • 從4k到42k,軟體測驗工程師的漲薪史,給我看哭了

    清明節一過,盲猜大家已經無心上班,在數著日子準備過五一,但一想到銀行卡里的余額……瞬間心情就不美麗了。最近,2023年高校畢業生就業調查顯示,本科畢業月平均起薪為5825元。調查一出,便有很多同學表示自己又被平均了。看著這一資料,不免讓人想到前不久中國青年報的一項調查:近六成大學生認為畢業10年內會 ......

    uj5u.com 2023-04-20 07:44:00 more
  • 最新版本 Stable Diffusion 開源 AI 繪畫工具之中文自動提詞篇

    🎈 標簽生成器 由于輸入正向提示詞 prompt 和反向提示詞 negative prompt 都是使用英文,所以對學習母語的我們非常不友好 使用網址:https://tinygeeker.github.io/p/ai-prompt-generator 這個網址是為了讓大家在使用 AI 繪畫的時候 ......

    uj5u.com 2023-04-20 07:43:36 more
  • 漫談前端自動化測驗演進之路及測驗工具分析

    隨著前端技術的不斷發展和應用程式的日益復雜,前端自動化測驗也在不斷演進。隨著 Web 應用程式變得越來越復雜,自動化測驗的需求也越來越高。如今,自動化測驗已經成為 Web 應用程式開發程序中不可或缺的一部分,它們可以幫助開發人員更快地發現和修復錯誤,提高應用程式的性能和可靠性。 ......

    uj5u.com 2023-04-20 07:43:16 more
  • CANN開發實踐:4個DVPP記憶體問題的典型案例解讀

    摘要:由于DVPP媒體資料處理功能對存放輸入、輸出資料的記憶體有更高的要求(例如,記憶體首地址128位元組對齊),因此需呼叫專用的記憶體申請介面,那么本期就分享幾個關于DVPP記憶體問題的典型案例,并給出原因分析及解決方法。 本文分享自華為云社區《FAQ_DVPP記憶體問題案例》,作者:昇騰CANN。 DVPP ......

    uj5u.com 2023-04-20 07:43:03 more
  • msf學習

    msf學習 以kali自帶的msf為例 一、msf核心模塊與功能 msf模塊都放在/usr/share/metasploit-framework/modules目錄下 1、auxiliary 輔助模塊,輔助滲透(埠掃描、登錄密碼爆破、漏洞驗證等) 2、encoders 編碼器模塊,主要包含各種編碼 ......

    uj5u.com 2023-04-20 07:42:59 more
  • Halcon軟體安裝與界面簡介

    1. 下載Halcon17版本到到本地 2. 雙擊安裝包后 3. 步驟如下 1.2 Halcon軟體安裝 界面分為四大塊 1. Halcon的五個助手 1) 影像采集助手:與相機連接,設定相機引數,采集影像 2) 標定助手:九點標定或是其它的標定,生成標定檔案及內參外參,可以將像素單位轉換為長度單位 ......

    uj5u.com 2023-04-20 07:42:17 more
  • 在MacOS下使用Unity3D開發游戲

    第一次發博客,先發一下我的游戲開發環境吧。 去年2月份買了一臺MacBookPro2021 M1pro(以下簡稱mbp),這一年來一直在用mbp開發游戲。我大致分享一下我的開發工具以及使用體驗。 1、Unity 官網鏈接: https://unity.cn/releases 我一般使用的Apple ......

    uj5u.com 2023-04-20 07:40:19 more