[Reinforcement Learning] Implementing Q-learning in Python, Example 5 (GUI)


Author: hhh5460

Original post: https://www.cnblogs.com/hhh5460/p/10143579.html

Thanks to pengdali; the Maze class in this post draws on his blog post at https://blog.csdn.net/pengdali/article/details/79369966

0. Problem Setting

A 6*6 maze with the entrance at the top-left corner and the exit at the bottom-right corner. The red rectangle is the player, the black rectangles are traps, and the yellow rectangle is the treasure, as illustrated in the maze layout below.

1. Problem Analysis

Convert the two-dimensional problem into a one-dimensional one, then solve it.

State set: [0, 1, ..., 35], 36 states in total.

Action set: ['u', 'd', 'l', 'r'] (up, down, left, right), 4 actions in total.

Reward set: [0, -10, 0, 0, ..., 10], one reward value per cell, 36 values in total. Empty cells give 0, traps give -10, the treasure gives 3, and the exit gives 10 (an arbitrary choice).

Next we need to pick the appropriate elements or subsets from the three sets above, and the one-dimensional representation makes this simpler than a two-dimensional one. The method is the same as in the earlier examples: only the first and last rows and columns need special handling at the boundaries, as in the sketch below, so the details are not repeated here.
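
To make the one-dimensional indexing concrete, here is a minimal sketch (the helper names are only illustrative; the full Agent class in section 2 does the same thing): state s lies in column s % 6 and row s // 6, so moving right/left changes s by +1/-1 and moving down/up by +6/-6, and the boundary rows and columns simply lose the actions that would leave the grid.

def state_to_xy(state):
    '''Column (x) and row (y) of a one-dimensional state index on the 6*6 grid.'''
    return state % 6, state // 6

def valid_actions(state):
    '''Actions that keep the player inside the 6*6 grid.'''
    actions = set('udlr')
    x, y = state_to_xy(state)
    if x == 0: actions.discard('l')  # first column: no left
    if x == 5: actions.discard('r')  # last column: no right
    if y == 0: actions.discard('u')  # first row: no up
    if y == 5: actions.discard('d')  # last row: no down
    return sorted(actions)

print(state_to_xy(10))     # (4, 1) -- the trap in column 4, row 1
print(valid_actions(0))    # ['d', 'r'] -- only down and right at the entrance
print(valid_actions(35))   # ['l', 'u'] -- only left and up at the exit

The update_q_value() method in the code below then applies the standard Q-learning rule Q(s, a) += alpha * (reward + gamma * max Q(s', ·) - Q(s, a)).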

 

2. Complete Code

import pandas as pd
import random
import time
import pickle
import pathlib
import tkinter as tk

'''
A 6*6 maze:
-------------------------------------------
| Entry| Trap |      |      |      |      |
-------------------------------------------
|      | Trap |      |      | Trap |      |
-------------------------------------------
|      | Trap |      | Trap |      |      |
-------------------------------------------
|      | Trap |      | Trap |      |      |
-------------------------------------------
|      | Trap |      | Trap | Gold |      |
-------------------------------------------
|      |      |      | Trap |      | Exit |
-------------------------------------------

Author: hhh5460
Date:   20181219
Place:  Tai Zi Miao
'''


class Maze(tk.Tk):
    '''Environment class (GUI)'''
    UNIT = 40  # pixels
    MAZE_H = 6  # grid height
    MAZE_W = 6  # grid width

    def __init__(self):
        '''Initialize'''
        super().__init__()
        self.title('Maze')
        h = self.MAZE_H * self.UNIT
        w = self.MAZE_W * self.UNIT
        self.geometry('{0}x{1}'.format(w, h)) # window size: width x height
        self.canvas = tk.Canvas(self, bg='white', height=h, width=w)
        # draw the grid lines
        for c in range(0, w, self.UNIT):
            self.canvas.create_line(c, 0, c, h)
        for r in range(0, h, self.UNIT):
            self.canvas.create_line(0, r, w, r)
        # draw the traps
        self._draw_rect(1, 0, 'black')
        self._draw_rect(1, 1, 'black')
        self._draw_rect(1, 2, 'black')
        self._draw_rect(1, 3, 'black')
        self._draw_rect(1, 4, 'black')
        self._draw_rect(3, 2, 'black')
        self._draw_rect(3, 3, 'black')
        self._draw_rect(3, 4, 'black')
        self._draw_rect(3, 5, 'black')
        self._draw_rect(4, 1, 'black')
        # draw the treasure
        self._draw_rect(4, 4, 'yellow')
        # draw the player (keep a handle to this rectangle!)
        self.rect = self._draw_rect(0, 0, 'red')
        self.canvas.pack() # show the canvas!
        
    def _draw_rect(self, x, y, color):
        '''Draw a rectangle; x, y are the column and row indices of the cell'''
        padding = 5 # inner padding of 5px (cf. CSS)
        coor = [self.UNIT * x + padding, self.UNIT * y + padding, self.UNIT * (x+1) - padding, self.UNIT * (y+1) - padding]
        return self.canvas.create_rectangle(*coor, fill = color)

    def move_to(self, state, delay=0.01):
        '''Move the player to the new cell given by the state'''
        coor_old = self.canvas.coords(self.rect) # e.g. [5.0, 5.0, 35.0, 35.0] (top-left and bottom-right corners of the first cell)
        x, y = state % 6, state // 6 # column and row indices of the cell
        padding = 5 # inner padding of 5px (cf. CSS)
        coor_new = [self.UNIT * x + padding, self.UNIT * y + padding, self.UNIT * (x+1) - padding, self.UNIT * (y+1) - padding]
        dx_pixels, dy_pixels = coor_new[0] - coor_old[0], coor_new[1] - coor_old[1] # difference of the top-left corner coordinates
        self.canvas.move(self.rect, dx_pixels, dy_pixels)
        self.update() # tkinter's built-in update()!
        time.sleep(delay)


class Agent(object):
    '''Agent class'''
    def __init__(self, alpha=0.1, gamma=0.9):
        '''Initialize'''
        self.states = range(36)     # state set: 0..35, 36 states in total
        self.actions = list('udlr') # action set: up, down, left, right, 4 actions
        self.rewards = [0,-10,0,  0,  0, 0,
                        0,-10,0,  0,-10, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  0, 0,
                        0,-10,0,-10,  3, 0,
                        0,  0,0,-10,  0,10,] # reward set: exit 10, trap -10, treasure 3
        self.hell_states = [1,7,10,13,15,19,21,25,27,33] # trap states (consistent with the rewards above)
        
        self.alpha = alpha
        self.gamma = gamma
        
        self.q_table = pd.DataFrame(data=[[0.0 for _ in self.actions] for _ in self.states], # float Q values
                                    index=self.states, 
                                    columns=self.actions)
    
    def save_policy(self):
        '''Save the Q table'''
        with open('q_table.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.q_table, f, pickle.HIGHEST_PROTOCOL)
    
    def load_policy(self):
        '''Load the Q table'''
        with open('q_table.pickle', 'rb') as f:
            self.q_table = pickle.load(f)
    
    def choose_action(self, state, epsilon=0.8):
        '''Choose an action for the current state: explore randomly with probability 1-epsilon, otherwise act greedily'''
        #if (random.uniform(0,1) > epsilon) or ((self.q_table.loc[state] == 0).all()):  # explore
        if random.uniform(0,1) > epsilon:             # explore
            action = random.choice(self.get_valid_actions(state))
        else:
            #action = self.q_table.loc[state].idxmax() # exploit; with several equal maxima this locks onto the first one!
            #action = self.q_table.loc[state].filter(items=self.get_valid_actions(state)).idxmax() # big improvement! but it still has the same tie problem as above
            s = self.q_table.loc[state].filter(items=self.get_valid_actions(state))
            action = random.choice(s[s==s.max()].index) # pick one of the (possibly several) maxima at random!
        return action
    
    def get_q_values(self, state):
        '''Get the Q values of all valid actions for the given state'''
        q_values = self.q_table.loc[state, self.get_valid_actions(state)]
        return q_values
        
    def update_q_value(self, state, action, next_state_reward, next_state_q_values):
        '''Update the Q value according to the Q-learning (Bellman) update rule'''
        self.q_table.loc[state, action] += self.alpha * (next_state_reward + self.gamma * next_state_q_values.max() - self.q_table.loc[state, action])
    
    def get_valid_actions(self, state):
        '''Get all legal actions for the current state'''
        valid_actions = set(self.actions)
        if state % 6 == 5:               # last column:
            valid_actions -= set(['r'])  # no move right
        if state % 6 == 0:               # first column:
            valid_actions -= set(['l'])  # no move left
        if state // 6 == 5:              # last row:
            valid_actions -= set(['d'])  # no move down
        if state // 6 == 0:              # first row:
            valid_actions -= set(['u'])  # no move up
        return list(valid_actions)
    
    def get_next_state(self, state, action):
        '''Get the next state after taking an action in the given state'''
        #u,d,l,r,n = -6,+6,-1,+1,0
        if state % 6 != 5 and action == 'r':    # any column but the last can move right (+1)
            next_state = state + 1
        elif state % 6 != 0 and action == 'l':  # any column but the first can move left (-1)
            next_state = state - 1
        elif state // 6 != 5 and action == 'd': # any row but the last can move down (+6)
            next_state = state + 6
        elif state // 6 != 0 and action == 'u': # any row but the first can move up (-6)
            next_state = state - 6
        else:
            next_state = state
        return next_state
    
    def learn(self, env=None, episode=1000, epsilon=0.8):
        '''Q-learning algorithm'''
        print('Agent is learning...')
        for i in range(episode):
            current_state = self.states[0]
            
            if env is not None: # if an environment is given, reset it first!
                env.move_to(current_state)
                
            while current_state != self.states[-1]:
                current_action = self.choose_action(current_state, epsilon) # choose randomly or greedily, depending on epsilon
                next_state = self.get_next_state(current_state, current_action)
                next_state_reward = self.rewards[next_state]
                next_state_q_values = self.get_q_values(next_state)
                self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values)
                current_state = next_state
                
                #if next_state not in self.hell_states: # if not a trap, move on; otherwise stay put
                #    current_state = next_state
                
                if env is not None: # if an environment is given, update it!
                    env.move_to(current_state)
            print(i)
        print('\nok')
        
    def test(self):
        '''Test whether the agent has learned a usable policy'''
        count = 0
        current_state = self.states[0]
        while current_state != self.states[-1]:
            current_action = self.choose_action(current_state, 1.) # epsilon=1.0: pure greedy
            next_state = self.get_next_state(current_state, current_action)
            current_state = next_state
            count += 1
            
            if count > 36:   # failed to escape the maze within 36 steps, so
                return False # not smart yet
        
        return True  # smart enough
    
    def play(self, env=None, delay=0.5):
        '''Play the game using the learned policy'''
        assert env is not None, 'Env must not be None!'
        
        if not self.test(): # if the agent is not smart yet
            if pathlib.Path("q_table.pickle").exists():
                self.load_policy()
            else:
                print("I need to learn before playing this game.")
                self.learn(env, episode=1000, epsilon=0.5)
                self.save_policy()
        
        print('Agent is playing...')
        current_state = self.states[0]
        env.move_to(current_state, delay)
        while current_state != self.states[-1]:
            current_action = self.choose_action(current_state, 1.) # epsilon=1.0: pure greedy
            next_state = self.get_next_state(current_state, current_action)
            current_state = next_state
            env.move_to(current_state, delay)
        print('\nCongratulations, Agent got it!')


if __name__ == '__main__':
    env = Maze()    # environment
    agent = Agent() # agent
    #agent.learn(env, episode=1000, epsilon=0.6) # learn first
    #agent.save_policy()
    #agent.load_policy()
    agent.play(env)                             # then play
    
    #env.after(0, agent.learn, env, 1000, 0.8) # learn first
    #env.after(0, agent.save_policy) # save what was learned
    #env.after(0, agent.load_policy) # load what was learned
    #env.after(0, agent.play, env)            # then play
    env.mainloop()

 

Major improvement: Agent.choose_action(). Previously, the greedy branch used idxmax() directly, which always locks onto the direction of the first maximum!
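
For a quick standalone illustration of the difference (the Q values below are made up purely for this demo):

import random
import pandas as pd

# Q values for one state: 'u' and 'r' are tied for the maximum.
q = pd.Series({'u': 0.5, 'd': 0.1, 'l': 0.2, 'r': 0.5})

print(q.idxmax())                            # always 'u' -- the first maximum wins
print(random.choice(q[q == q.max()].index))  # 'u' or 'r', chosen at random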

