強化學習實戰(1):gridworld


 

參考:https://orzyt.cn/posts/gridworld/

Reinforcement Learning: An Introduction》在第三章中給出了一個簡單的例子:Gridworld, 以幫助我們理解finite MDPs,

同時也求解了該問題的貝爾曼期望方程貝爾曼最優方程. 本文簡要說明如何進行編程求解.

問題

下圖用一個矩形網格展示了一個簡單finite MDP - Gridworld.
網格中的每一格對應於environment的一個state.
在每一格, 有四種可能的actions:上/下/左/右, 對應於agent往相應的方向移動一個單元格.
使agent離開網格的actions會使得agent留在原來的位置, 但是會有一個值為-1的reward.
除了那些使得agent離開state Astate B的action, 其他的actions對應的reward都是0.
處在state A時, 所有的actions會有值為+10的reward, 並且agent會移動到state A'.
處在state B時, 所有的actions會有值為+5的reward, 並且agent會移動到state B'.

 

元素

  • 狀態(State): 網格的坐標, 共 $5 \times 5 = 25$ 個狀態;
  • 動作(Action): 上/下/左/右四種動作;
  • 策略(Policy): $\pi(a | s) = \frac14 \; \text{for} \; \forall s \in S, \text{and} \; \forall \; a \in \{↑,↓,←,→\}$;
  • 獎勵(Reward): 如題所述;
  • 折扣因子(Discount rate): $\gamma \in [0, 1]$, 本文采用 $\gamma=0.9$。

目標

  • 使用貝爾曼期望方程, 求解給定隨機策略 $\pi(a | s) = \frac14$ 下的狀態值函數.
  • 使用貝爾曼最優方程, 求解最優狀態值函數.

 

實現

  1 import numpy as np
  2 
  3 %matplotlib inline
  4 import matplotlib
  5 import matplotlib.pyplot as plt
  6 from matplotlib.table import Table
  7 
  8 #定義grid問題中常用的變量
  9 # 格子尺寸
 10 WORLD_SIZE = 5
 11 # 狀態A的位置(下標從0開始,下同)
 12 A_POS = [0, 1]
 13 # 狀態A'的位置
 14 A_PRIME_POS = [4, 1]
 15 # 狀態B的位置
 16 B_POS = [0, 3]
 17 # 狀態B'的位置
 18 B_PRIME_POS = [2, 3]
 19 # 折扣因子
 20 DISCOUNT = 0.9
 21 
 22 # 動作集={上,下,左,右}
 23 ACTIONS = [np.array([-1, 0]),
 24            np.array([1, 0]),
 25            np.array([0, 1]),
 26            np.array([0, -1]),
 27 ]
 28 # 策略,每個動作等概率
 29 ACTION_PROB = 0.25
 30 
 31 
 32 #繪圖相關函數
 33 def draw_image(image):
 34     fig, ax = plt.subplots()
 35     ax.set_axis_off()
 36     tb = Table(ax, bbox=[0, 0, 1, 1])
 37 
 38     nrows, ncols = image.shape
 39     width, height = 1.0 / ncols, 1.0 / nrows
 40 
 41     # 添加表格
 42     for (i,j), val in np.ndenumerate(image):
 43         tb.add_cell(i, j, width, height, text=val, 
 44                     loc='center', facecolor='white')
 45 
 46     # 行標簽
 47     for i, label in enumerate(range(len(image))):
 48         tb.add_cell(i, -1, width, height, text=label+1, loc='right', 
 49                     edgecolor='none', facecolor='none')
 50     # 列標簽
 51     for j, label in enumerate(range(len(image))):
 52         tb.add_cell(WORLD_SIZE, j, width, height/2, text=label+1, loc='center', 
 53                            edgecolor='none', facecolor='none')
 54     ax.add_table(tb)
 55 
 56 
 57 
 58 def step(state, action):
 59     '''給定當前狀態以及采取的動作,返回后繼狀態及其立即獎勵
 60     
 61     Parameters
 62     ----------
 63     state : list
 64         當前狀態
 65     action : list
 66         采取的動作
 67     
 68     Returns
 69     -------
 70     tuple
 71         后繼狀態,立即獎勵
 72         
 73     '''
 74     # 如果當前位置為狀態A,則直接跳到狀態A',獎勵為+10
 75     if state == A_POS:
 76         return A_PRIME_POS, 10
 77     # 如果當前位置為狀態B,則直接跳到狀態B',獎勵為+5
 78     if state == B_POS:
 79         return B_PRIME_POS, 5
 80 
 81     state = np.array(state)
 82     # 通過坐標運算得到后繼狀態
 83     next_state = (state + action).tolist()
 84     x, y = next_state
 85     # 根據后繼狀態的坐標判斷是否出界
 86     if x < 0 or x >= WORLD_SIZE or y < 0 or y >= WORLD_SIZE:
 87         # 出界則待在原地,獎勵為-1
 88         reward = -1.0
 89         next_state = state
 90     else:
 91         # 未出界則獎勵為0
 92         reward = 0
 93     return next_state, reward
 94 
 95 
 96 a
 97 π(a|s)[r+γ
 98 v
 99 π
100 (
101 s
102 103 )]
104 vπ=∑aπ(a|s)[r+γvπ(s′)]
105 In [5]:
106 def bellman_equation():
107     ''' 求解貝爾曼(期望)方程
108     '''
109     # 初始值函數
110     value = np.zeros((WORLD_SIZE, WORLD_SIZE))
111     while True:
112         new_value = np.zeros(value.shape)
113         # 遍歷所有狀態
114         for i in range(0, WORLD_SIZE):
115             for j in range(0, WORLD_SIZE):
116                 # 遍歷所有動作
117                 for action in ACTIONS:
118                     # 執行動作,轉移到后繼狀態,並獲得立即獎勵
119                     (next_i, next_j), reward = step([i, j], action)
120                     # 貝爾曼期望方程
121                     new_value[i, j] += ACTION_PROB * \
122                     (reward + DISCOUNT * value[next_i, next_j])
123         # 迭代終止條件: 誤差小於1e-4
124         if np.sum(np.abs(value - new_value)) < 1e-4:
125             draw_image(np.round(new_value, decimals=2))
126             plt.title('$v_{\pi}$')
127             plt.show()
128             plt.close()
129             break
130         value = new_value
131 
132 def bellman_optimal_equation():
133     '''求解貝爾曼最優方程
134     '''
135     # 初始值函數
136     value = np.zeros((WORLD_SIZE, WORLD_SIZE))
137     while True:
138         new_value = np.zeros(value.shape)
139         # 遍歷所有狀態
140         for i in range(0, WORLD_SIZE):
141             for j in range(0, WORLD_SIZE):
142                 values = []
143                 # 遍歷所有動作
144                 for action in ACTIONS:
145                     # 執行動作,轉移到后繼狀態,並獲得立即獎勵
146                     (next_i, next_j), reward = step([i, j], action)
147                     # 緩存動作值函數 q(s,a) = r + γ*v(s')
148                     values.append(reward + DISCOUNT * value[next_i, next_j])
149                 # 根據貝爾曼最優方程,找出最大的動作值函數 q(s,a) 進行更新
150                 new_value[i, j] = np.max(values)
151         # 迭代終止條件: 誤差小於1e-4
152         if np.sum(np.abs(new_value - value)) < 1e-4:
153             draw_image(np.round(new_value, decimals=2))
154             plt.title('$v_{*}$')
155             plt.show()
156             plt.close()
157             break
158         value = new_value
159 
160 
161 bellman_equation()
162 
163 bellman_optimal_equation()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM