用 Java 訓練出一只“不死鳥”


作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手機游戲,因其簡單的玩法但極度困難的設定迅速走紅全網。隨着深度學習(DL)與增強學習(RL)等前沿算法的發展,我們可以使用 Java 非常方便地訓練出一個智能體來控制 Flappy Bird。

故事開始於《GitHub 上的大佬們打完招呼,會聊些什么?》,那么,今天我們就來一起看一下如何用 Java 訓練出一個不死鳥。游戲項目我們使用了一個僅用 Java 基本類庫編寫的 FlappyBird 游戲。在訓練方面,我們使用 DeepJavaLibrary 一個基於 Java 的深度學習框架來構建增強學習訓練網絡並進行訓練。經過了差不多 300 萬步(四小時)的訓練后,小鳥已經可以獲得最高 8000 多分的成績,靈活穿梭於水管之間。

在本文中,我們將從原理開始一步一步實現增強學習算法並用它對游戲進行訓練。如果任何一個時刻不清楚如何繼續進行下去,可以參閱項目的源碼。

項目地址:https://github.com/kingyuluk/RL-FlappyBird

增強學習(RL)的架構

在這一節會介紹主要用到的算法以及神經網絡,幫助你更好的了解如何進行訓練。本項目與 DeepLearningFlappyBird 使用了類似的方法進行訓練。算法整體的架構是 Q-Learning + 卷積神經網絡(CNN),把游戲每一幀的狀態存儲起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經網絡的訓練數據。

CNN 訓練簡述

CNN 的輸入數據為連續的 4 幀圖像,我們將這圖像 stack 起來作為小鳥當前的“observation”,圖像會轉換成灰度圖以減少所需的訓練資源。圖像存儲的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 數組里的元素就是當前幀的像素值,這些數據將輸入到 CNN 后將輸出 (batch size, 2) 的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應的收益。

訓練數據

在小鳥采取動作后,我們會得到 preObservation and currentObservation 即是兩組 4 幀的連續的圖像表示小鳥動作前和動作后的狀態。然后我們將 preObservation, currentObservation, action, reward, terminal 組成的五元組作為一個 step 存進 replayBuffer 中。它是一個有限大小的訓練數據集,他會隨着最新的操作動態更新內容。

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

訓練的三個周期

訓練分為 3 個不同的周期以更好地生成訓練數據:

  • Observe(觀察) 周期:隨機產生訓練數據
  • Explore (探索) 周期:隨機與推理動作結合更新訓練數據
  • Training (訓練) 周期:推理動作主導產生新數據

通過這種訓練模式,我們可以更好的達到預期效果。

處於 Explore 周期時,我們會根據權重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作。訓練前期,隨機動作的權重會非常大,因為模型的決策十分不准確 (甚至不如隨機)。在訓練后期時,隨着模型學習的動作逐步增加,我們會不斷增加模型推理動作的權重並最終使它成為主導動作。調節隨機動作的參數叫做 epsilon 它會隨着訓練的過程不斷變化。

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

訓練邏輯

首先,我們會從 replayBuffer 中隨機抽取一批數據作為作為訓練集。然后將 preObservation 輸入到神經網絡得到所有行為的 reward(Q)作為預測值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

postObservation 同樣會輸入到神經網絡,根據馬爾科夫決策過程以及貝爾曼價值函數計算出所有行為的 reward(targetQ)作為真實值:

// 將 postInput 輸入到神經網絡中得到 targetQReward 是 (batchsize,2) 的矩陣。根據 Q-learning 的算法,每一次的 targetQ 需要根據當前環境是否結束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在訓練結束時,計算 Q 和 targetQ 的損失值,並在 CNN 中更新權重。

卷積神經網絡模型(CNN)

我們采用了采用了 3 個卷積層,4 個 relu 激活函數以及 2 個全連接層的神經網絡架構。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

訓練過程

DJL 的 RL 庫中提供了非常方便的用於實現強化學習的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 實現 RlAgent 接口即可構建一個可以進行訓練的智能體。
  • 在現有的游戲環境中實現 RlEnv 接口即可生成訓練所需的數據。
  • 創建 ReplayBuffer 可以存儲並動態更新訓練數據。

在實現這些接口后,只需要調用 step 方法:

RlEnv.step(action, training);

這個方法會將 RlAgent 決策出的動作輸入到游戲環境中獲得反饋。我們可以在 RlEnv 中提供的 runEnviroment 方法中調用 step 方法,然后只需要重復執行 runEnvironment 方法,即可不斷地生成用於訓練的數據。

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

我們將 ReplayBuffer 可存儲的 step 數量設置為 50000,在 observe 周期我們會先向 replayBuffer 中存儲 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學習。

在 explore 和 training 周期,神經網絡會隨機從 replayBuffer 中生成訓練集並將它們輸入到模型中訓練。我們使用 Adam 優化器和 MSE 損失函數迭代神經網絡。

神經網絡輸入預處理

首先將圖像大小 resize 成 80x80 並轉為灰度圖,這有助於在不丟失信息的情況下提高訓練速度。

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

然后我們把連續的四幀圖像作為一個輸入,為了獲得連續四幀的連續圖像,我們維護了一個全局的圖像隊列保存游戲線程中的圖像,每一次動作后替換掉最舊的一幀,然后把隊列里的圖像 stack 成一個單獨的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

一旦以上部分完成,我們就可以開始訓練了。訓練優化為了獲得最佳的訓練性能,我們關閉了 GUI 以加快樣本生成速度。並使用 Java 多線程將訓練循環和樣本生成循環分別在不同的線程中運行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

總結

這個模型在 NVIDIA T4 GPU 訓練了大概 4 個小時,更新了 300 萬步。訓練后的小鳥已經可以完全自主控制動作靈活穿梭與管道之間。訓練后的模型也同樣上傳到了倉庫中供您測試。在此項目中 DJL 提供了強大的訓練 API 以及模型庫支持,使得在 Java 開發過程中得心應手。

本項目完整代碼:https://github.com/kingyuluk/RL-FlappyBird


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM