作者:Kingyu & Lanking
FlappyBird 是 2013 年推出的一款手機游戲,因其簡單的玩法但極度困難的設定迅速走紅全網,隨著深度學習(DL)與增強學習(RL)等前沿演算法的發展,我們可以使用 Java 非常方便地訓練出一個智能體來控制 Flappy Bird,
故事開始于《GitHub 上的大佬們打完招呼,會聊些什么?》,今天我們就來一起看一下如何用 Java 訓練出一個不死鳥,游戲專案我們使用了一個僅用 Java 基本類別庫撰寫的 FlappyBird 游戲,在訓練方面,我們使用 DeepJavaLibrary 一個基于 Java 的深度學習框架來構建增強學習訓練網路并進行訓練,經過了300 萬步(四小時)的訓練后,小鳥已經可以獲得最高 8000 多分的成績,靈活穿梭于水管之間,
在本文中,我們將從原理開始一步一步實作增強學習演算法并用它對游戲進行訓練,如果任何一個時刻不清楚如何繼續進行下去,可以參閱專案的原始碼,
專案地址:https://github.com/kingyuluk/RL-FlappyBird
增強學習(RL)的架構
在這一節會介紹主要用到的演算法以及神經網路,幫助你更好的了解如何進行訓練,本專案與 DeepLearningFlappyBird 使用了類似的方法進行訓練,演算法整體的架構是 Q-Learning + 卷積神經網路(CNN),把游戲每一幀的狀態存盤起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經網路的訓練資料,
CNN 訓練簡述
CNN 的輸入資料為連續的 4 幀影像,我們將這影像 stack 起來作為小鳥當前的“observation”,影像會轉換成灰度圖以減少所需的訓練資源,影像存盤的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 陣列里的元素就是當前幀的像素值,這些資料將輸入到 CNN 后將輸出 (batch size, 2) 的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應的收益,
訓練資料
在小鳥采取動作后,我們會得到 preObservation and currentObservation 即是兩組 4 幀的連續的影像表示小鳥動作前和動作后的狀態,然后我們將 preObservation, currentObservation, action, reward, terminal 組成的五元組作為一個 step 存進 replayBuffer 中,它是一個有限大小的訓練資料集,他會隨著最新的操作動態更新內容,
public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}
訓練的三個周期
訓練分為 3 個不同的周期以更好地生成訓練資料:
Observe(觀察) 周期:隨機產生訓練資料
Explore (探索) 周期:隨機與推理動作結合更新訓練資料
Training (訓練) 周期:推理動作主導產生新資料
通過這種訓練模式,我們可以更好的達到預期效果,
處于 Explore 周期時,我們會根據權重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作,訓練前期,隨機動作的權重會非常大,因為模型的決策十分不準確 (甚至不如隨機),在訓練后期時,隨著模型學習的動作逐步增加,我們會不斷增加模型推理動作的權重并最終使它成為主導動作,調節隨機動作的引數叫做 epsilon 它會隨著訓練的程序不斷變化,
public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}
訓練邏輯
首先,我們會從 replayBuffer 中隨機抽取一批資料作為作為訓練集,然后將 preObservation 輸入到神經網路得到所有行為的 reward(Q)作為預測值:
NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));
postObservation 同樣會輸入到神經網路,根據馬爾科夫決策程序以及貝爾曼價值函式計算出所有行為的 reward(targetQ)作為真實值:
// 將 postInput 輸入到神經網路中得到 targetQReward 是 (batchsize,2) 的矩陣,根據 Q-learning 的演算法,每一次的 targetQ 需要根據當前環境是否結束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList,
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
在訓練結束時,計算 Q 和 targetQ 的損失值,并在 CNN 中更新權重,
卷積神經網路模型(CNN)
我們采用了采用了 3 個卷積層,4 個 relu 激活函式以及 2 個全連接層的神經網路架構,
| layer | input shape | output shape |
|---|---|---|
| conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
| conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
| conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
| linear | (batchSize, 3136) | (batchSize, 512) |
| linear | (batchSize, 512) | (batchSize, 2) |
訓練程序
DJL 的 RL 庫中提供了非常方便的用于實作強化學習的介面:(RlEnv, RlAgent, ReplayBuffer),
實作 RlAgent 介面即可構建一個可以進行訓練的智能體,
在現有的游戲環境中實作 RlEnv 介面即可生成訓練所需的資料,
創建 ReplayBuffer 可以存盤并動態更新訓練資料,
在實作這些介面后,只需要呼叫 step 方法:
RlEnv.step(action, training);
這個方法會將 RlAgent 決策出的動作輸入到游戲環境中獲得反饋,我們可以在 RlEnv 中提供的 runEnviroment 方法中呼叫 step 方法,然后只需要重復執行 runEnvironment 方法,即可不斷地生成用于訓練的資料,
public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}
我們將 ReplayBuffer 可存盤的 step 數量設定為 50000,在 observe 周期我們會先向 replayBuffer 中存盤 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學習,
在 explore 和 training 周期,神經網路會隨機從 replayBuffer 中生成訓練集并將它們輸入到模型中訓練,我們使用 Adam 優化器和 MSE 損失函式迭代神經網路,
神經網路輸入預處理
首先將影像大小 resize 成 80x80 并轉為灰度圖,這有助于在不丟失資訊的情況下提高訓練速度,
public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}
然后我們把連續的四幀影像作為一個輸入,為了獲得連續四幀的連續影像,我們維護了一個全域的影像佇列保存游戲執行緒中的影像,每一次動作后替換掉最舊的一幀,然后把佇列里的影像 stack 成一個單獨的 NDArray,
public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}
一旦以上部分完成,我們就可以開始訓練了,訓練優化為了獲得最佳的訓練性能,我們關閉了 GUI 以加快樣本生成速度,并使用 Java 多執行緒將訓練回圈和樣本生成回圈分別在不同的執行緒中運行,
List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}
總結
這個模型在 NVIDIA T4 GPU 訓練了大概 4 個小時,更新了 300 萬步,訓練后的小鳥已經可以完全自主控制動作靈活穿梭于管道之間,訓練后的模型也同樣上傳到了倉庫中供您測驗,在此專案中 DJL 提供了強大的訓練 API 以及模型庫支持,使得在 Java 開發程序中得心應手,
本專案完整代碼:https://github.com/kingyuluk/RL-FlappyBird
關注后點擊“入伙”,加入我們
「閱讀原文,點個 star 吧」
CSDN認證博客專家
萌新
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/242456.html
標籤:其他
