參考:https://orzyt.cn/posts/gridworld/
Reinforcement Learning: An Introduction》在第三章中給出了一個簡單的例子:Gridworld
, 以幫助我們理解finite MDPs
,
同時也求解了該問題的貝爾曼期望方程和貝爾曼最優方程. 本文簡要說明如何進行編程求解.
問題
下圖用一個矩形網格展示了一個簡單finite MDP - Gridworld
.
網格中的每一格對應於environment的一個state.
在每一格, 有四種可能的actions:上/下/左/右
, 對應於agent往相應的方向移動一個單元格.
使agent離開網格的actions會使得agent留在原來的位置, 但是會有一個值為-1
的reward.
除了那些使得agent離開state A
和state B
的action, 其他的actions對應的reward都是0
.
處在state A
時, 所有的actions會有值為+10
的reward, 並且agent會移動到state A'
.
處在state B
時, 所有的actions會有值為+5
的reward, 並且agent會移動到state B'
.
元素
- 狀態(State): 網格的坐標, 共 $5 \times 5 = 25$ 個狀態;
- 動作(Action):
上/下/左/右
四種動作; - 策略(Policy): $\pi(a | s) = \frac14 \; \text{for} \; \forall s \in S, \text{and} \; \forall \; a \in \{↑,↓,←,→\}$;
- 獎勵(Reward): 如題所述;
- 折扣因子(Discount rate): $\gamma \in [0, 1]$, 本文采用 $\gamma=0.9$。
目標
- 使用貝爾曼期望方程, 求解給定隨機策略 $\pi(a | s) = \frac14$ 下的狀態值函數.
- 使用貝爾曼最優方程, 求解最優狀態值函數.
實現
1 import numpy as np 2 3 %matplotlib inline 4 import matplotlib 5 import matplotlib.pyplot as plt 6 from matplotlib.table import Table 7 8 #定義grid問題中常用的變量 9 # 格子尺寸 10 WORLD_SIZE = 5 11 # 狀態A的位置(下標從0開始,下同) 12 A_POS = [0, 1] 13 # 狀態A'的位置 14 A_PRIME_POS = [4, 1] 15 # 狀態B的位置 16 B_POS = [0, 3] 17 # 狀態B'的位置 18 B_PRIME_POS = [2, 3] 19 # 折扣因子 20 DISCOUNT = 0.9 21 22 # 動作集={上,下,左,右} 23 ACTIONS = [np.array([-1, 0]), 24 np.array([1, 0]), 25 np.array([0, 1]), 26 np.array([0, -1]), 27 ] 28 # 策略,每個動作等概率 29 ACTION_PROB = 0.25 30 31 32 #繪圖相關函數 33 def draw_image(image): 34 fig, ax = plt.subplots() 35 ax.set_axis_off() 36 tb = Table(ax, bbox=[0, 0, 1, 1]) 37 38 nrows, ncols = image.shape 39 width, height = 1.0 / ncols, 1.0 / nrows 40 41 # 添加表格 42 for (i,j), val in np.ndenumerate(image): 43 tb.add_cell(i, j, width, height, text=val, 44 loc='center', facecolor='white') 45 46 # 行標簽 47 for i, label in enumerate(range(len(image))): 48 tb.add_cell(i, -1, width, height, text=label+1, loc='right', 49 edgecolor='none', facecolor='none') 50 # 列標簽 51 for j, label in enumerate(range(len(image))): 52 tb.add_cell(WORLD_SIZE, j, width, height/2, text=label+1, loc='center', 53 edgecolor='none', facecolor='none') 54 ax.add_table(tb) 55 56 57 58 def step(state, action): 59 '''給定當前狀態以及采取的動作,返回后繼狀態及其立即獎勵 60 61 Parameters 62 ---------- 63 state : list 64 當前狀態 65 action : list 66 采取的動作 67 68 Returns 69 ------- 70 tuple 71 后繼狀態,立即獎勵 72 73 ''' 74 # 如果當前位置為狀態A,則直接跳到狀態A',獎勵為+10 75 if state == A_POS: 76 return A_PRIME_POS, 10 77 # 如果當前位置為狀態B,則直接跳到狀態B',獎勵為+5 78 if state == B_POS: 79 return B_PRIME_POS, 5 80 81 state = np.array(state) 82 # 通過坐標運算得到后繼狀態 83 next_state = (state + action).tolist() 84 x, y = next_state 85 # 根據后繼狀態的坐標判斷是否出界 86 if x < 0 or x >= WORLD_SIZE or y < 0 or y >= WORLD_SIZE: 87 # 出界則待在原地,獎勵為-1 88 reward = -1.0 89 next_state = state 90 else: 91 # 未出界則獎勵為0 92 reward = 0 93 return next_state, reward 94 95 96 a 97 π(a|s)[r+γ 98 v 99 π 100 ( 101 s 102 ′ 103 )] 104 vπ=∑aπ(a|s)[r+γvπ(s′)] 105 In [5]: 106 def bellman_equation(): 107 ''' 求解貝爾曼(期望)方程 108 ''' 109 # 初始值函數 110 value = np.zeros((WORLD_SIZE, WORLD_SIZE)) 111 while True: 112 new_value = np.zeros(value.shape) 113 # 遍歷所有狀態 114 for i in range(0, WORLD_SIZE): 115 for j in range(0, WORLD_SIZE): 116 # 遍歷所有動作 117 for action in ACTIONS: 118 # 執行動作,轉移到后繼狀態,並獲得立即獎勵 119 (next_i, next_j), reward = step([i, j], action) 120 # 貝爾曼期望方程 121 new_value[i, j] += ACTION_PROB * \ 122 (reward + DISCOUNT * value[next_i, next_j]) 123 # 迭代終止條件: 誤差小於1e-4 124 if np.sum(np.abs(value - new_value)) < 1e-4: 125 draw_image(np.round(new_value, decimals=2)) 126 plt.title('$v_{\pi}$') 127 plt.show() 128 plt.close() 129 break 130 value = new_value 131 132 def bellman_optimal_equation(): 133 '''求解貝爾曼最優方程 134 ''' 135 # 初始值函數 136 value = np.zeros((WORLD_SIZE, WORLD_SIZE)) 137 while True: 138 new_value = np.zeros(value.shape) 139 # 遍歷所有狀態 140 for i in range(0, WORLD_SIZE): 141 for j in range(0, WORLD_SIZE): 142 values = [] 143 # 遍歷所有動作 144 for action in ACTIONS: 145 # 執行動作,轉移到后繼狀態,並獲得立即獎勵 146 (next_i, next_j), reward = step([i, j], action) 147 # 緩存動作值函數 q(s,a) = r + γ*v(s') 148 values.append(reward + DISCOUNT * value[next_i, next_j]) 149 # 根據貝爾曼最優方程,找出最大的動作值函數 q(s,a) 進行更新 150 new_value[i, j] = np.max(values) 151 # 迭代終止條件: 誤差小於1e-4 152 if np.sum(np.abs(new_value - value)) < 1e-4: 153 draw_image(np.round(new_value, decimals=2)) 154 plt.title('$v_{*}$') 155 plt.show() 156 plt.close() 157 break 158 value = new_value 159 160 161 bellman_equation() 162 163 bellman_optimal_equation()