組合游戲1: 詳解Minimax 和 Alpha Beta剪枝算法


本系列,我們來看看在一種常見的組合游戲——回合制棋盤類游戲中,如何用算法來解決問題。首先,我們會介紹並解決搜索空間較小的問題,引入經典的博弈算法和相關理論,最終實現在大搜索空間中的Deep RL近似算法。在此基礎上可以理解AlphaGo的原理和工作方式。本系列的第一篇,我們介紹3個Leetcode中的零和回合制游戲,從最初的暴力解法,到動態規划最終演變成博弈論里的經典算法:minimax 以及 alpha beta 剪枝。

獲得最好的閱讀體驗,請點擊最下方 閱讀原文,並在電腦上打開

  • 第一篇 [Leetcode中的Minimax 和 Alpha Beta剪枝]

  • 第二篇: 一些組合游戲的理論

  • 第三篇: 連接N個點 的OpenAI Gym GUI環境

  • 第四篇: 蒙特卡洛樹搜索(MCTS)和時間差分學習(TD learning)

Leetcode 292 Nim Game (簡單)

簡單題 Leetcode 292 Nim Game。

你和你的朋友,兩個人一起玩 Nim游戲:桌子上有一堆石頭,每次你們輪流拿掉 1 - 3 塊石頭。拿掉最后一塊石頭的人就是獲勝者。你作為先手。
你們是聰明人,每一步都是最優解。編寫一個函數,來判斷你是否可以在給定石頭數量的情況下贏得游戲。

示例:
輸入: 4
輸出: false
解釋: 如果堆中有 4 塊石頭,那么你永遠不會贏得比賽;因為無論你拿走 1 塊、2 塊 還是 3 塊石頭,最后一塊石頭總是會被你的朋友拿走。

定義  為有個石頭並采取最優策略的游戲結果, 的值只有可能是贏或者輸。考察前幾個結果:,然后來計算。因為玩家采取最優策略(只要有一種走法讓對方必輸,玩家獲勝),對於4來說,玩家能走的可能是拿掉1塊、2塊或3塊,但是無論剩余何種局面,對方都是必贏,因此,4就是必輸。總的說來,遞歸關系如下:

這個遞歸式可以直接翻譯成Python 3代碼

# TLE
# Time Complexity: O(exponential)
class Solution_BruteForce:

    def canWinNim(self, n: int) -> bool:
        if n <= 3:
            return True
        for i in range(1, 4):
            if not self.canWinNim(n - i):
                return True
        return False

以上的遞歸公式和代碼很像fibonacci數的遞歸定義和暴力解法,因此對應的時間復雜度也是指數級的,提交代碼以后會TLE。下圖畫出了當n=7時的遞歸調用,注意 5 被擴展向下重復執行了兩次,4重復了4次。

 

 

292 Nim Game 暴力解法調用圖  n=7

我們采用和fibonacci一樣的方式來優化算法:緩存較小n的結果以此來計算較大n的結果。Python 中,我們可以只加一行lru_cache decorator,來取得這種動態規划效果,下面的代碼將復雜度降到了 。

# RecursionError: maximum recursion depth exceeded in comparison n=1348820612
# Time Complexity: O(N)
class Solution_DP:
    from functools import lru_cache
    @lru_cache(maxsize=None)
    def canWinNim(self, n: int) -> bool:
        if n <= 3:
            return True
        for i in range(1, 4):
            if not self.canWinNim(n - i):
                return True
        return False

再來畫出調用圖:這次5和4就不再被展開重復計算,圖中綠色的節點表示緩存命中。

 292 Nim Game 動歸解法調用圖 n=7但還是沒有AC,因為當n=1348820612時,這種方式會導致棧溢出。再改成下面的循環版本,可惜還是TLE。

# TLE for 1348820612
# Time Complexity: O(N)
class Solution:
    def canWinNim(self, n: int) -> bool:
        if n <= 3:
            return True
        last3, last2, last1 = True, True, True
        for i in range(4, n+1):
            this = not (last3 and last2 and last1)
            last3, last2, last1 = last2, last1, this
        return last1

由此看來,AC 版本需要低於的算法復雜度。上面的寫法似乎暗示輸贏有周期性的規律。事實上,如果將輸贏按照順序畫出來,就馬上得出規律了:只要 就是輸,否則贏。原因如下:當面臨不能被4整除的數量時  ,一方總是可以拿走  個,將 留給對手,而對方下輪又將返回不能被4整除的數,如此循環往復,直到這一方有1, 2, 3 個,最終獲勝。

 輸贏分布

最終AC版本,只有一句語句。

AC
Time Complexity: O(1)
class Solution:
    def canWinNim(self, n: int) -> bool:
        return not (n % 4 == )

Leetcode 486 Predict the Winner (中等)

中等難度題目:Leetcode 486 Predict the Winner.

給定一個表示分數的非負整數數組。玩家1從數組任意一端拿取一個分數,隨后玩家2繼續從剩余數組任意一端拿取分數,然后玩家1拿,……。每次一個玩家只能拿取一個分數,分數被拿取之后不再可取。直到沒有剩余分數可取時游戲結束。最終獲得分數總和最多的玩家獲勝。
給定一個表示分數的數組,預測玩家1是否會成為贏家。你可以假設每個玩家的玩法都會使他的分數最大化。

示例 1:
輸入: [1, 5, 2]
輸出: False
解釋: 一開始,玩家1可以從1和2中進行選擇。
如果他選擇2(或者1),那么玩家2可以從1(或者2)和5中進行選擇。如果玩家2選擇了5,那么玩家1則只剩下1(或者2)可選。
所以,玩家1的最終分數為 1 + 2 = 3,而玩家2為 5。
因此,玩家1永遠不會成為贏家,返回 False。

示例 2:
輸入: [1, 5, 233, 7]
輸出: True
解釋: 玩家1一開始選擇1。然后玩家2必須從5和7中進行選擇。無論玩家2選擇了哪個,玩家1都可以選擇233。
最終,玩家1(234分)比玩家2(12分)獲得更多的分數,所以返回 True,表示玩家1可以成為贏家。

對於當前玩家,他有兩種選擇:左邊或者右邊的數。定義 maxDiff(l, r) 為剩余子數組時,當前玩家能取得的最大分差,那么

對應的時間復雜度可以寫出遞歸式,顯然是指數級的:

采用暴力解法可以AC,但運算時間很長,接近TLE邊緣 (6300ms)。

# AC
# Time Complexity: O(2^N)
# Slow: 6300ms
from typing import List

class Solution:

    def maxDiff(self, l: int, r:int) -> int:
        if l == r:
            return self.nums[l]
        return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

    def PredictTheWinner(self, nums: List[int]) -> bool:
        self.nums = nums
        return self.maxDiff(, len(nums) - 1) >= 

從調用圖也很容易看出是指數級的復雜度

 486 Predict the Winner 暴力解法調用圖 n=4

上圖中我們有重復計算的節點,例如[1-2]節點被計算了兩次。使用 lru_cache 大法,在maxDiff 上僅加了一句,就能以復雜度 和運行時間 43ms  AC。

# AC
# Time Complexity: O(N^2)
# Fast: 43ms
from functools import lru_cache
from typing import List

class Solution:

    @lru_cache(maxsize=None)
    def maxDiff(self, l: int, r:int) -> int:
        if l == r:
            return self.nums[l]
        return max(self.nums[l] - self.maxDiff(l + 1, r), self.nums[r] - self.maxDiff(l, r - 1))

    def PredictTheWinner(self, nums: List[int]) -> bool:
        self.nums = nums
        return self.maxDiff(, len(nums) - 1) >= 

動態規划解法調用圖可以看出節點 [1-2] 這次沒有被計算兩次。

 486 Predict the Winner 動歸解法調用圖 n=4

Leetcode 464 Can I Win (中等)

類似但稍有難度的題目 Leetcode 464 Can I Win。難點在於使用了位的狀態壓縮。

在 "100 game" 這個游戲中,兩名玩家輪流選擇從 1 到 10 的任意整數,累計整數和,先使得累計整數和達到 100 的玩家,即為勝者。
如果我們將游戲規則改為 “玩家不能重復使用整數” 呢?
例如,兩個玩家可以輪流從公共整數池中抽取從 1 到 15 的整數(不放回),直到累計整數和 >= 100。
給定一個整數 maxChoosableInteger (整數池中可選擇的最大數)和另一個整數 desiredTotal(累計和),判斷先出手的玩家是否能穩贏(假設兩位玩家游戲時都表現最佳)?
你可以假設 maxChoosableInteger 不會大於 20, desiredTotal 不會大於 300。

示例:
輸入:
maxChoosableInteger = 10
desiredTotal = 11
輸出:
false
解釋:
無論第一個玩家選擇哪個整數,他都會失敗。
第一個玩家可以選擇從 1 到 10 的整數。
如果第一個玩家選擇 1,那么第二個玩家只能選擇從 2 到 10 的整數。
第二個玩家可以通過選擇整數 10(那么累積和為 11 >= desiredTotal),從而取得勝利.
同樣地,第一個玩家選擇任意其他整數,第二個玩家都會贏。

# AC
# Time Complexity: O:(2^m*m), m: maxChoosableInteger
class Solution:
    from functools import lru_cache
    @lru_cache(maxsize=None)
    def recurse(self, status: int, currentTotal: int) -> bool:
        for i in range(1, self.maxChoosableInteger + 1):
            if not (status >> i & 1):
                new_status = 1 << i | status
                if currentTotal + i >= self.desiredTotal:
                    return True
                if not self.recurse(new_status, currentTotal + i):
                    return True
        return False


    def canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:
        self.maxChoosableInteger = maxChoosableInteger
        self.desiredTotal = desiredTotal

        sum = maxChoosableInteger * (maxChoosableInteger + 1) / 2
        if sum < desiredTotal:
            return False
        return self.recurse(, )

上面的代碼算法復雜度為,m是maxChoosableInteger。由於所有狀態的數量是,對於每個狀態,最多會嘗試  走法。

Minimax 算法

至此,我們AC了leetcode中的幾道零和回合制博弈游戲。事實上,在這個領域有通用的算法:回合制博弈下的minimax。算法背景如下,兩個玩家輪流玩,第一個玩家max的目的是將游戲的效用最大化,第二個玩家min則是最小化效用。比如,下面的節點表示玩家選取節點后游戲的效用,當兩個玩家都能采取最優策略,Minimax 算法從底層節點來計算,游戲的結果是最終max 玩家會得到-7。

 Wikipedia Minimax 例子

Minimax Python 3偽代碼如下。

def minimax(node: Node, depth: int, maximizingPlayer: bool) -> int:
    if depth ==  or is_terminal(node):
        return evaluate_terminal(node)
    if maximizingPlayer:
        value:int = −∞
        for child in node:
            value = max(value, minimax(child, depth − 1, False))
        return value
    else: # minimizing player
        value := +∞
        for child in node:
            value = min(value, minimax(child, depth − 1True))
        return value

Minimax: 486 Predict the Winner

我們知道486 Predict the Winner 是有minimax解法的,但如何具體實現,其難點在於如何定義合適的游戲價值或者效用。之前的手機號碼購買平台地圖解法中,我們定義maxDiff(l, r) 來表示當前玩家面臨子區間  時能取得的最大分差。對於minimax算法,max 玩家要最大化游戲價值,min玩家要最小化游戲價值。先考慮最簡單情況即只有一個數x時,若定義max玩家在此局面下得到這個數時游戲價值為 +x,則min玩家為-x,即max玩家得到的所有數為正(),min玩家得到的所有數為負()。至此,max玩家的目標就是  ,min玩家是 。有了精確的定義和優化目標,代碼只需要套一下上面的模版。

# AC
from functools import lru_cache
from typing import List

class Solution:
    # max_player: max(A - B)
    # min_player: min(A - B)
    @lru_cache(maxsize=None)
    def minimax(self, l: int, r: int, isMaxPlayer: bool) -> int:
        if l == r:
            return self.nums[l] * (if isMaxPlayer else -1)

        if isMaxPlayer:
            return max(
                self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
                self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))
        else:
            return min(
                -self.nums[l] + self.minimax(l + 1, r, not isMaxPlayer),
                -self.nums[r] + self.minimax(l, r - 1, not isMaxPlayer))

    def PredictTheWinner(self, nums: List[int]) -> bool:
        self.nums = nums
        v = self.minimax(, len(nums) - 1, True)
        return v >= 

 Minimax 486 調用圖 nums=[1, 5, 2, 4]

Minimax: 464 Can I Win

該題目是很典型的此類游戲,即結果為贏輸平,但是中間的狀態沒有直接對應的游戲價值。對於這樣的問題,一般定義為,max玩家勝,價值 +1,min玩家勝,價值-1,平則0。下面的AC代碼實現了 Minimax 算法。算法中針對兩個玩家都有剪枝(沒有剪枝無法AC)。具體來說,max玩家一旦在某一節點取得勝利(value=1),就停止繼續向下搜索,因為這是他能取得的最好分數。同理,min玩家一旦取得-1也直接返回上層節點。這個剪枝可以泛化成 alpha beta剪枝算法。

# AC
class Solution:
    from functools import lru_cache
    @lru_cache(maxsize=None)
    # currentTotal < desiredTotal
    def minimax(self, status: int, currentTotal: int, isMaxPlayer: bool) -> int:
        import math
        if status == self.allUsed:
            return   # draw: no winner

        if isMaxPlayer:
            value = -math.inf
            for i in range(1, self.maxChoosableInteger + 1):
                if not (status >> i & 1):
                    new_status = 1 << i | status
                    if currentTotal + i >= self.desiredTotal:
                        return 1  # shortcut
                    value = max(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
                    if value == 1:
                        return 1
            return value
        else:
            value = math.inf
            for i in range(1, self.maxChoosableInteger + 1):
                if not (status >> i & 1):
                    new_status = 1 << i | status
                    if currentTotal + i >= self.desiredTotal:
                        return -1  # shortcut
                    value = min(value, self.minimax(new_status, currentTotal + i, not isMaxPlayer))
                    if value == -1:
                        return -1
            return value

Alpha-Beta 剪枝

在464 Can I Win minimax 算法代碼實現中,我們發現有剪枝優化空間。對於每個節點,定義兩個值alpha 和 beta,表示從根節點到目前局面時,max玩家保證能取得的最小值以及min玩家能保證取得的最大值。初始時,根節點alpha = −∞ , beta = +∞,表示游戲最終的價值在區間 [−∞, +∞]中。在向下遍歷的過程中,子節點先繼承父節點的 alpha beta 值進而繼承區間 [alpha, beta]。當子節點在向下遍歷的時候同步更新alpha 或者 beta,一旦區間[alpha, beta]非法就立即向上返回。舉個Wikimedia的例子來進一步說明:

  1. 根節點初始時:alpha = −∞, beta = +∞

  2. 根節點,最左邊子節點返回4后:alpha = 4, beta = +∞

  3. 根節點,中間子節點返回5后:alpha = 5, beta = +∞

  4. 最右Min節點(標1節點),初始時:alpha = 5, beta = +∞

  5. 最右Min節點(標1節點),第一個子節點返回1后:alpha = 5, beta = 1

此時,最右Min節點的alpha, beta形成了無效區間[5, 1],滿足了剪枝條件,因此可以不用計算它的第二個和第三個子節點。如果剩余子節點返回值 > 1,比如2,由於這是個min節點,將會被已經到手的1替換。若其他子節點返回值  < 1,但由於min的父節點有效區間是[5,  +∞],已經保證了>=5,小於5的值也會被忽略。

 Wikimedia Alpha Beta 剪枝例子Minimax Python 3偽代碼如下

def alpha_beta(node: Node, depth: int, α: int, β: int, maximizingPlayer: bool) -> int:
    if depth ==  or is_terminal(node):
        return evaluate_terminal(node)
    if maximizingPlayer:
        value: int = −∞
        for child in node:
            value = max(value, alphabeta(child, depth − 1, α, β, False))
            α = max(α, value)
            if α >= β:
                break # β cut-off
        return value
    else:
        value: int = +∞
        for child in node:
            value = min(value, alphabeta(child, depth − 1, α, β, True))
            β = min(β, value)
            if β <= α:
                break # α cut-off
        return value

Alpha-Beta Pruning: 486 Predict the Winner

# AC
import math
from functools import lru_cache
from typing import List

class Solution:
    def alpha_beta(self, l: int, r: int, curr: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
        if l == r:
            return curr + self.nums[l] * (if isMaxPlayer else -1)

        if isMaxPlayer:
            ret = self.alpha_beta(l + 1, r, curr + self.nums[l], not isMaxPlayer, alpha, beta)
            alpha = max(alpha, ret)
            if alpha >= beta:
                return alpha
            ret = max(ret, self.alpha_beta(l, r - 1, curr + self.nums[r], not isMaxPlayer, alpha, beta))
            return ret
        else:
            ret = self.alpha_beta(l + 1, r, curr - self.nums[l], not isMaxPlayer, alpha, beta)
            beta = min(beta, ret)
            if alpha >= beta:
                return beta
            ret = min(ret, self.alpha_beta(l, r - 1, curr - self.nums[r], not isMaxPlayer, alpha, beta))
            return ret

    def PredictTheWinner(self, nums: List[int]) -> bool:
        self.nums = nums
        v = self.alpha_beta(, len(nums) - 1, , True, -math.inf, math.inf)
        return v >= 

Alpha-Beta Pruning: 464 Can I Win

# AC
class Solution:
    from functools import lru_cache
    @lru_cache(maxsize=None)
    # currentTotal < desiredTotal
    def alpha_beta(self, status: int, currentTotal: int, isMaxPlayer: bool, alpha: int, beta: int) -> int:
        import math
        if status == self.allUsed:
            return   # draw: no winner

        if isMaxPlayer:
            value = -math.inf
            for i in range(1, self.maxChoosableInteger + 1):
                if not (status >> i & 1):
                    new_status = 1 << i | status
                    if currentTotal + i >= self.desiredTotal:
                        return 1  # shortcut
                    value = max(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
                    alpha = max(alpha, value)
                    if alpha >= beta:
                        return value
            return value
        else:
            value = math.inf
            for i in range(1, self.maxChoosableInteger + 1):
                if not (status >> i & 1):
                    new_status = 1 << i | status
                    if currentTotal + i >= self.desiredTotal:
                        return -1  # shortcut
                    value = min(value, self.alpha_beta(new_status, currentTotal + i, not isMaxPlayer, alpha, beta))
                    beta = min(beta, value)
                    if alpha >= beta:
                        return value
            return value

C++, Java, Javascript AC 486 Predict the Winner

最后介紹一種不同的DP實現:用C++, Java, Javascript 實現自底向上的DP解法來AC leetcode 486,當然其他語言沒有Python的lru_cache大法。以下實現中,注意DP解的構建順序,先解決小規模的問題,並在此基礎上計算稍大的問題。值得一提的是,以下的循環寫法嚴格保證了  次循環,但是自頂向下的計划遞歸可能會少於 次循環。

Java AC Code

// AC
class Solution {
    public boolean PredictTheWinner(int[] nums) {
        int n = nums.length;
        int[][] dp = new int[n][n];
        for (int i = ; i ni++) {
            dp[i][i] = nums[i];
        }

        for (int l = n - 1l >= ; l--) {
            for (int r = l + 1; r nr++) {
                dp[l][r] = Math.max(
                        nums[l- dp[l + 1][r],
                        nums[r- dp[l][r - 1]);
            }
        }
        return dp[][n - 1] >= ;
    }
}

C++ AC Code

// AC
class Solution {
public:
    bool PredictTheWinner(vector<int>& nums) {
        int n = nums.size();
        vector<vector<int>> dp(n, vector<int>(n, ));
        for (int i = ; i < n; i++) {
          dp[i][i] = nums[i];
        }
        for (int l = n - 1; l >= ; l--) {
            for (int r = l + 1; r < n; r++) {
                dp[l][r] = max(nums[l] - dp[l + 1][r], nums[r] - dp[l][r - 1]);
            }
        }
        return dp[][n - 1] >= ;
    }
};

Javascript AC Code

/**
 * @param {number[]} nums
 * @return {boolean}
 */
var PredictTheWinner = function(nums) {
    const n = nums.length;
    const dp = new Array(n).fill().map(() => new Array(n));

    for (let i = ; i < n; i++) {
      dp[i][i] = nums[i];
    }
  
    for (let l = n - 1; l >=; l--) {
        for (let r = i + 1; r < n; r++) {
            dp[l][r] = Math.max(nums[l] - dp[l + 1][r],nums[r] - dp[l][r - 1]);
        }
    }
  
    return dp[][n-1] >=;
};

著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM