詳細
先上效果圖:
- 啟動界面
- 主界面
- 設置界面
- 服務器界面(使用highchart模板畫出每一局得分情況)
配置的兩款簡單小游戲以及訓練效果:
- 貪吃蛇
- “是男人就下一百層”(修改)
*原圖像太大被迫修改大小
使用說明:
【設置窗口】
→在上面的主界面中點擊倒三角形狀的鍵,屏幕上會彈出一個黑色的設置窗。在該窗口界面上,用戶可以通過拖動滑塊條、在框內輸入具體數值兩種方法設置模型參數。滑塊條和編輯框互聯。
【在服務器上查看訓練結果】
→點擊最小化按鈕,將會復制服務器地址(本機IP:9090)到剪切板上,可以將其粘貼到瀏覽器中實時監測訓練情況。AI端每隔五秒向temp.db數據庫提交一次新數據,網頁中的折線圖每隔三秒從temp.db數據庫中獲取最新數據並更新折線圖,實現實時數據可視化。
【關閉按鈕】
→當點擊關閉按鈕時,若訓練次數超過1000幀,將會彈出窗口詢問是否保存記錄。否則會由於訓練次數過少,對訓練沒有意義而直接退出不保存結果,以提高效率。
→點擊確認
→成功保存
【新建模式訓練】
→選擇訓練游戲
→開始訓練(點擊播放按鈕)
→鼠標放在進度條上能看到具體數值
【加載模式訓練】
→點擊切換按鈕
→此時再點擊播放按鈕,會彈出窗口用於選擇加載模型
→點擊開始按鈕開始訓練,同時設置窗口按鈕、模式轉換按鈕都會失效,以確保訓練順利進行。
1、相關配置
- Python 3
- TensorFlow-gpu
- pygame
- OpenCV-Python
- PyQt5
- sys
- threading
- multiprocessing
- shelve
- os
- sqlite3
- socket
- pyperclip
- flask
- glob
- shutil
- numpy
- pandas
- time
- importlib
2、文件目錄
|————MyLibrary.py 用於設置游戲中人物等類
|————run_window.py 啟動主程序,包括啟動界面
|————mainwindow.py 主界面程序
|————setting.py 參數調節窗口程序
|————message_box.py 消息框窗口程序
|————DQL.py 人工智能主程序,負責選擇和啟動游戲、啟動深度強化學習內核
|————DQLBrain.py 深度強化學習內核
|————game_setting.py 存儲已有游戲決策狀態數、庫名等信息,新游戲加入必須將相關信息也加入在其中
|————flask_tk.py 服務器文件
|————jumpMan.py 跳跳人游戲文件
|————greedySnake.py 貪吃蛇游戲文件
|————resource 窗口圖片資源文件夾
|————saved_networks 已得出的模型文件
|————templates
|————index.html 網頁前端模板文件
|————static
|————exporting.js
|————highcharts-zh_CN.js
|————highstock.js
|————jquery.js
|————temp.db 臨時數據庫,用於服務器和AI端數據交互使用
|————greedy_snake.data-00000-of-00001
|————greedy_snake.index
|————greedy_snake.meta 以上三個為一個訓練好的模型
|————greedy_snake.db.bak
|————greedy_snake.db.dat
|————greedy_snake.db.dir 以上三個為一個模型文件
|————setting_resource.py 設定窗口的資源文件
|————resource_message_box.py 消息框窗口的資源文件
|————resource.py 主窗口的資源文件
|————document.py 根據數據庫文件自動化生成報告
3、實現過程
整個demo主要分為四大部分:主窗口、算法和游戲內核、服務器以及管理版本數據庫文件部分。
- 啟動界面
import sys
from mainWindow import MAINWINDOW
from PyQt5.QtWidgets import QApplication,QSplashScreen
from PyQt5 import QtCore,QtGui,QtWidgets
def _launch():
    """Show the splash screen for three seconds, then open the main window."""
    application = QApplication(sys.argv)

    # Build and display the splash screen
    splash_screen = QSplashScreen(QtGui.QPixmap("啟動界面.png"))
    splash_screen.show()

    # Keep the splash screen visible for 3 s while still servicing Qt events
    stopwatch = QtCore.QElapsedTimer()
    stopwatch.start()
    while True:
        if stopwatch.elapsed() >= 3000:
            break
        application.processEvents()

    # Create and show the main window
    main_window = MAINWINDOW()
    main_window.show()

    # The splash screen disappears once the main window is fully loaded
    splash_screen.finish(main_window)
    return application.exec_()


if __name__ == '__main__':
    sys.exit(_launch())
- 主界面(均使用Qtdesigner完成)
import gameSetting
import resource
from PyQt5 import QtWidgets,QtCore,QtGui
from collections import deque
from threading import Thread
from multiprocessing import Process
import shelve
import sqlite3
import socket
import pyperclip
from DQL import AI
import setting
import messageBox
import webServers
import glob
import shutil
game_start=False
class myThread(Thread):
    """Worker thread that builds the AI for one game and runs its training loop.

    Parameters:
        game: name of the game to play (a key of the game settings).
        model: path prefix of a saved model, or "" for a fresh session.
        replay_memory: replay buffer handed to the AI.
        timestep: number of timesteps already trained in earlier sessions.
        setting: mapping with the keys "Explore", "Initial", "Final",
            "Gamma", "Replay" and "Batch"; values may be numbers or numeric
            strings, they are converted in run().
    """

    def __init__(self, game, model, replay_memory, timestep, setting):
        Thread.__init__(self)
        self.game = game
        self.model = model
        self.setting = setting
        self.replay_memory = replay_memory
        self.timestep = timestep
        # The AI is only created inside run(); until then it stays None so
        # that stop() and outside readers see a well-defined attribute.
        self.AI = None

    def run(self):
        """Create the AI with the converted settings and play until it ends."""
        self.AI = AI(
            self.game,
            self.model,
            self.replay_memory,
            self.timestep,
            int(self.setting["Explore"]),
            float(self.setting["Initial"]),
            float(self.setting["Final"]),
            float(self.setting["Gamma"]),
            int(self.setting["Replay"]),
            int(self.setting["Batch"]),
        )
        self.AI.playGame()

    def stop(self):
        """Close the game; a no-op if the AI has not been created yet."""
        if self.AI is not None:
            self.AI.closeGame()
class MAINWINDOW(QtWidgets.QWidget):
    """Frameless main window of the deep reinforcement learning toolbox.

    The window lets the user pick a game, start a new training session or
    resume a saved project (run by ``myThread``), watch the progress, and save
    the session as a project: a ``shelve`` file plus copies of the network
    checkpoint files. Creating the window also starts the score web server in
    a daemon process.
    """

    # Schema of the table in temp.db that is shared with the AI and the server.
    _CREATE_SCORES_SQL = 'create table if not exists scores (time integer, score integer)'

    def __init__(self, parent=None):
        # Parent-class initialisation
        super().__init__()
        # Session / drag state, initialised up front so that event handlers
        # and closeWindow() never hit a missing attribute.
        self.m_drag = False
        self.m_DragPosition = QtCore.QPoint()
        self.state = {}
        self.program_name = ""
        self.current_game = ""
        self.actual_timestep = 0
        # Main form
        self.setObjectName("Form")
        self.setEnabled(True)
        self.resize(681, 397)
        self.setStyleSheet("background-color: rgb(255, 255, 255);")
        self.setWindowFlags(QtCore.Qt.FramelessWindowHint)
        # Progress bar
        self.progressBar = QtWidgets.QProgressBar(self)
        self.progressBar.setEnabled(True)
        self.progressBar.setGeometry(QtCore.QRect(140, 348, 291, 23))
        self.progressBar.setProperty("value", 0)
        self.progressBar.setTextVisible(False)
        self.progressBar.setObjectName("progressxzBar")
        # Start/stop button
        self.control = QtWidgets.QPushButton(self)
        self.control.setGeometry(QtCore.QRect(10, 325, 71, 71))
        self.control.setStyleSheet("border-image: url(:/bottom/resource/開始按鈕.png);")
        self.control.setText("")
        self.control.setObjectName("control")
        self.control_state = False
        # Game selection combo box
        self.game_selection = QtWidgets.QComboBox(self)
        self.game_selection.setEnabled(True)
        self.game_selection.setGeometry(QtCore.QRect(530, 343, 141, 31))
        self.game_selection.setAutoFillBackground(False)
        self.game_selection.setStyleSheet(
            "QComboBox{border-image: url(:/list/resource/下拉框.png)} \n"
            "QComboBox::drop-down {image: url(:/bottom/resource/下拉框按鈕.png) }")
        self.game_selection.setEditable(False)
        self.game_selection.setInsertPolicy(QtWidgets.QComboBox.NoInsert)
        self.game_selection.setIconSize(QtCore.QSize(0, 0))
        self.game_selection.setFrame(False)
        self.game_selection.setObjectName("game_selection")
        # Mode (new / load) toggle button
        self.mode = QtWidgets.QPushButton(self)
        self.mode.setGeometry(QtCore.QRect(440, 340, 71, 41))
        self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n")
        self.mode.setText("")
        self.mode.setObjectName("mode")
        self.mode_state = False
        # Background image
        self.label = QtWidgets.QLabel(self)
        self.label.setGeometry(QtCore.QRect(0, 0, 681, 331))
        self.label.setStyleSheet("border-image: url(:/image/resource/Background.png);")
        self.label.setText("")
        self.label.setObjectName("label")
        # Settings button
        self.setting = QtWidgets.QPushButton(self)
        self.setting.setGeometry(QtCore.QRect(570, 10, 31, 21))
        self.setting.setStyleSheet("border-image: url(:/bottom/resource/菜單.png);")
        self.setting.setText("")
        self.setting.setObjectName("setting")
        # "Copy server address" button (drawn with the minimise icon)
        self.pushButton_3 = QtWidgets.QPushButton(self)
        self.pushButton_3.setGeometry(QtCore.QRect(610, 10, 31, 23))
        self.pushButton_3.setStyleSheet("border-image: url(:/bottom/resource/最小化.png);")
        self.pushButton_3.setText("")
        self.pushButton_3.setObjectName("pushButton_3")
        # Close button
        self.bottom_close = QtWidgets.QPushButton(self)
        self.bottom_close.setGeometry(QtCore.QRect(650, 10, 21, 23))
        self.bottom_close.setStyleSheet("border-image: url(:/bottom/resource/關閉.png);")
        self.bottom_close.setText("")
        self.bottom_close.setObjectName("bottom_close")
        # Fill in texts, child windows, game list and start the web server
        self.init_window(self)
        # Wire button signals to their slots
        self.connectBottom()
        QtCore.QMetaObject.connectSlotsByName(self)

    # Window initialisation
    def init_window(self, Form):
        """Set the title, create child windows, list the games, start the server."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "深度強化學習工具箱"))
        # Child windows
        self.setting_form = setting.SETTING()
        self.message_box = messageBox.MESSAGE_BOX()
        # Load the list of games
        game_names = list(gameSetting.getSetting().keys())
        for game in game_names:
            self.game_selection.addItem(_translate("Form", game))
        if game_names:
            # Guarded: an empty game list must not crash window creation.
            self.game_selection.setCurrentText(_translate("Form", game_names[0]))
            self.game_selection.setCurrentIndex(0)
        # Start the web server in a daemon process so it ends with the GUI
        self.flask_process = Process(target=webServers.start)
        self.flask_process.daemon = True
        self.flask_process.start()

    # Connect all buttons to their handlers in one place
    def connectBottom(self):
        """Connect every button's clicked signal to its handler."""
        self.control.clicked.connect(self.loadGame)
        self.bottom_close.clicked.connect(self.closeWindow)
        self.mode.clicked.connect(self.setMode)
        self.setting.clicked.connect(self.openSetting)
        self.pushButton_3.clicked.connect(self.getIp)

    # Make the frameless window draggable
    def mousePressEvent(self, event):
        """Begin a drag when the left button is pressed."""
        if event.button() == QtCore.Qt.LeftButton:
            self.m_drag = True
            self.m_DragPosition = event.globalPos() - self.pos()
            event.accept()
            self.setCursor(QtGui.QCursor(QtCore.Qt.OpenHandCursor))

    def mouseMoveEvent(self, QMouseEvent):
        """Move the window while a left-button drag is in progress."""
        # buttons() reports the buttons held during the move; the bare
        # QtCore.Qt.LeftButton constant is always truthy and tests nothing.
        if self.m_drag and QMouseEvent.buttons() & QtCore.Qt.LeftButton:
            self.move(QMouseEvent.globalPos() - self.m_DragPosition)
            QMouseEvent.accept()

    def mouseReleaseEvent(self, QMouseEvent):
        """End the drag and restore the arrow cursor."""
        self.m_drag = False
        self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))

    @staticmethod
    def _reportError(message):
        """Print *message* and the active exception's traceback to stderr."""
        # Local imports: only needed on error paths.
        import sys
        import traceback
        sys.stderr.write(message + "\n")
        traceback.print_exc()

    def _loadProject(self, selected_file):
        """Read a saved project; return its contents as a dict, or None.

        None is returned when the dialog was cancelled (empty path) or the
        project file cannot be read completely.
        """
        if not selected_file:
            return None
        # The dialog returns ".../<name>.db.dat"; dropping the 7-character
        # ".db.dat" suffix leaves the project name (path without extension).
        program_name = selected_file[:-7]
        try:
            # NOTE(review): shelve unpickles the file's contents; only load
            # project files that come from a trusted source.
            # flag='r' avoids creating an empty shelf when the file is missing.
            with shelve.open(program_name + '.db', flag='r') as f:
                return {
                    "name": program_name,
                    "game": f["game"],
                    "replay": f["replay"],
                    "setting": f["setting"],
                    "timestep": int(f["timestep"]),
                    "result": f["result"],
                }
        except Exception:
            self._reportError("Could not load project " + program_name)
            return None

    def _restoreProjectUi(self, loaded_setting, results):
        """Best-effort: show the loaded settings and re-publish old scores."""
        try:
            self.setting_form.updateSetting(loaded_setting)
        except Exception:
            self._reportError("Could not show the loaded settings")
        try:
            self.updateDataset(results)
        except Exception:
            self._reportError("Could not restore the saved scores")

    # Handler of the start/stop button
    def loadGame(self):
        """Start a training session, or close the window if one is running."""
        global game_start
        self.mode.setEnabled(False)
        self.setting.setEnabled(False)
        # Flag that a game has been started
        game_start = True
        # control_state: False = not started yet, True = already running.
        # The button image follows this state.
        if self.control_state:
            self.closeWindow()
            return
        self.control.setStyleSheet("border-image: url(:/bottom/resource/終止按鈕.png);")
        self.control_state = True
        # Defaults for a brand-new session
        self.program_name = ""
        game = self.game_selection.currentText()
        model = ""
        replay_memory = deque()
        self.actual_timestep = 0
        setting = self.setting_form.getSetting()
        # Load mode: replace the defaults with the saved project's values.
        # If nothing usable is selected the session starts as a new one.
        if self.mode_state:
            program_path = QtWidgets.QFileDialog.getOpenFileName(
                self, "請選擇你想要加載的項目", "../", "Model File (*.dat)")
            project = self._loadProject(program_path[0])
            if project is not None:
                self.program_name = project["name"]
                game = project["game"]
                model = self.program_name
                replay_memory = project["replay"]
                setting = project["setting"]
                self.actual_timestep = project["timestep"]
                self._restoreProjectUi(setting, project["result"])
        # Remember which game is really running; it may differ from the
        # combo box selection when a project was loaded.
        self.current_game = game
        # Start the game thread
        self.game_thread = myThread(game, model, replay_memory, self.actual_timestep, setting)
        self.game_thread.start()
        # Start the timer that refreshes the window state every 5 s
        self.state_Timer = QtCore.QTimer()
        self.state_Timer.timeout.connect(self.updateState)
        self.state_Timer.start(5000)

    def _saveSession(self):
        """Stop the game and save project file and model (after confirmation)."""
        # Flush scores that are still uncommitted so they are part of the save.
        try:
            self.game_thread.AI.data_base.commit()
        except (AttributeError, sqlite3.Error):
            pass
        # Close the game window
        try:
            self.game_thread.AI.closeGame()
        except Exception:
            self._reportError("Could not close the game cleanly")
        if not self.program_name:
            # New mode: ask where to save
            chosen = QtWidgets.QFileDialog.getSaveFileName(
                self, "請選擇你保存項目的位置", "../", "Program File(*.db)")[0]
            if not chosen:
                # Dialog cancelled: nothing is saved
                return
            # Project name = chosen path without the ".db" extension; only the
            # trailing extension is removed so dots elsewhere in the path
            # are left alone.
            program_name = chosen[:-3] if chosen.endswith('.db') else chosen
            mode_flag = 0
        else:
            # Load mode: save back to the loaded project
            program_name = self.program_name
            mode_flag = 1
        try:
            self.saveProgram(program_name + '.db', mode_flag)
        except Exception:
            self._reportError("Could not save project " + program_name)
        # Save the model
        self.saveModel(program_name)

    def _clearTempDatabase(self):
        """Empty the scores table of temp.db (creating it if necessary)."""
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            connection.execute(self._CREATE_SCORES_SQL)
            connection.execute('delete from scores')
            connection.commit()
        except sqlite3.Error:
            self._reportError("Could not clear temp.db")
        finally:
            connection.close()

    # Close the window
    def closeWindow(self):
        """Offer to save when enough was trained, clear temp.db, close."""
        # If the game never started, or ran too briefly for the 5 s timer to
        # fetch a state, no timestep is known and the window just closes.
        timestep = self.state.get("TIMESTEP", 0)
        if timestep > 1000:
            # Ask whether the session should be saved
            if self.message_box.exec_():
                self._saveSession()
        # Empty the temporary database
        self._clearTempDatabase()
        # Close the main window
        self.close()

    # Single place that writes the project file
    def saveProgram(self, save_program_path, state):
        """Write settings, scores, replay buffer and timestep to a shelve file.

        Parameters:
            save_program_path: path of the shelve file, either a str or the
                (path, filter) tuple returned by QFileDialog.getSaveFileName.
            state: 0 for a new project, 1 for a loaded one. Kept for
                compatibility; the stored timestep is the same in both cases
                because the AI's "TIMESTEP" already includes the timesteps of
                the loaded project.
        """
        if isinstance(save_program_path, str):
            shelf_path = save_program_path
        else:
            shelf_path = save_program_path[0]
        ai = self.game_thread.AI
        # Running state of the AI
        ai_state = ai.getState()
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            rows = connection.execute('select * from scores').fetchall()
        finally:
            connection.close()
        with shelve.open(shelf_path) as f:
            # Settings the AI was run with
            f["setting"] = self.setting_form.getSetting()
            f["game"] = self.current_game or self.game_selection.currentText()
            f["epsilon"] = ai_state["EPSILON"]
            # Scores are stored with millisecond timestamps
            f["result"] = [[row[0] * 1000, row[1]] for row in rows]
            f["replay"] = ai.getReplay()
            f["timestep"] = int(ai_state["TIMESTEP"])

    # Periodically refresh the main window
    def updateState(self):
        """Fetch the AI state, update the progress bar, commit new scores."""
        # The AI may not exist yet when the game starts slowly; skip this tick.
        try:
            state = self.game_thread.AI.getState()
        except AttributeError:
            return
        # The state is an empty dict until the first training step.
        if state:
            self.state = state
            actual_timestep = state["TIMESTEP"]
            self.progressBar.setToolTip("Timestep:"+str(actual_timestep)+" STATE:"+state["STATE"]+" EPSILON:"+str(state["EPSILON"]))
            try:
                explore = float(self.setting_form.getSetting()["Explore"])
                percent = min(float(actual_timestep) / explore * 100, 100)
            except (ValueError, ZeroDivisionError):
                percent = 0
            self.progressBar.setProperty("value", percent)
        # Scores are committed only every 5 s to keep the game loop fast
        try:
            self.game_thread.AI.data_base.commit()
        except (AttributeError, sqlite3.Error):
            pass

    # Toggle the AI mode with the mode button
    def setMode(self):
        """Switch between new mode and load mode and update the button image."""
        if not self.mode_state:
            self.mode_state = True
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/加載模式.png);\n")
        else:
            self.mode_state = False
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n")

    # Get this machine's IP address
    def getIp(self):
        """Copy "<local ip>:9090" to the clipboard.

        Connecting a UDP socket selects the outgoing interface without sending
        any packet. Falls back to 127.0.0.1 when no route is available.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.connect(('8.8.8.8', 80))
                ip = sock.getsockname()[0]
        except OSError:
            ip = '127.0.0.1'
        pyperclip.copy(ip + ':9090')

    # Re-publish saved scores
    def updateDataset(self, results):
        """Insert saved [time_ms, score] pairs into the scores table of temp.db.

        saveProgram stores times in milliseconds while the table holds
        seconds, so the times are converted back before inserting.
        """
        rows = [(int(result[0]) // 1000, result[1]) for result in results]
        connection = sqlite3.connect('temp.db', check_same_thread=False)
        try:
            # The AI creates the table only when its thread starts, which is
            # after a project has been loaded.
            connection.execute(self._CREATE_SCORES_SQL)
            connection.executemany('insert into scores values (?,?)', rows)
            connection.commit()
        finally:
            connection.close()

    # Save the model
    def saveModel(self, program_name):
        """Copy the latest checkpoint files to "<program_name>.<extension>"."""
        import os  # local import: the module-level import list has no os

        for file in glob.glob("./saved_networks/network-dqn-*"):
            extension = os.path.splitext(file)[1]
            if not extension:
                continue
            try:
                shutil.copy(file, program_name + extension)
            except OSError:
                self._reportError("Could not copy " + file)

    # Handler of the settings button
    def openSetting(self):
        """Show the settings window."""
        self.setting_form.show()
- 設置窗口
from PyQt5 import QtCore, QtGui, QtWidgets
import setting_resource
class SETTING(QtWidgets.QWidget):
    """Settings window: six hyper-parameters, each with a slider and an edit box.

    The confirmed values are kept in ``self.setting`` under the keys
    "Explore", "Initial", "Final", "Gamma", "Replay" and "Batch".
    """

    # Style sheet shared by all sliders
    _SLIDER_STYLE = ("QSlider::handle:horizontal { \n"
                     " image: url(:/image/resource/Handle.png);\n"
                     "}\n"
                     "QSlider::groove:horizontal { \n"
                     " image: url(:/image/resource/Base.png);\n"
                     "}\n")

    def __init__(self):
        # Parent-class initialisation
        super().__init__()
        # Main window
        self.setObjectName("Dialog")
        self.resize(547, 402)
        self.setStyleSheet("")
        # Widgets are created in the same order as before (creation order
        # defines the default tab order).
        # OK and Cancel buttons
        self.pushButton = self._makeButton(
            160, "color: rgb(255, 255, 255);\nborder-image: url(:/image/resource/設定確定按鈕.png);", "pushButton")
        self.pushButton_2 = self._makeButton(
            320, "color: rgb(255, 255, 255);\nborder-image: url(:/image/resource/設定取消按鈕.png);", "pushButton_2")
        # Edit boxes
        self.line_explore = self._makeLineEdit(60, "line_explore")
        self.line_initial = self._makeLineEdit(100, "line_Initial")
        self.line_final = self._makeLineEdit(140, "line_final")
        self.line_gamma = self._makeLineEdit(180, "line_gamma")
        self.line_replay = self._makeLineEdit(220, "line_replay")
        self.line_batch = self._makeLineEdit(260, "line_batch")
        # Sliders and labels
        self.exploreSlider = self._makeSlider(60, "exploreSlider", 10000000, 200000, minimum=200000)
        self.label = self._makeLabel(60, 48, "label")
        self.label_2 = self._makeLabel(100, 48, "label_2")
        self.initialSlider = self._makeSlider(100, "initialSlider", 1000, 0)
        self.label_3 = self._makeLabel(140, 42, "label_3")
        self.finalSlider = self._makeSlider(140, "finalSlider", 1000, 0)
        self.label_4 = self._makeLabel(180, 42, "label_4")
        self.gammaSlider = self._makeSlider(180, "gammaSlider", 100, 99)
        self.label_6 = self._makeLabel(220, 42, "label_6")
        self.replaySlider = self._makeSlider(220, "replaySlider", 100000, 50000)
        self.label_7 = self._makeLabel(260, 36, "label_7")
        self.batchSlider = self._makeSlider(260, "batchSlider", 100, 32)
        # Background image
        self.label_5 = QtWidgets.QLabel(self)
        self.label_5.setGeometry(QtCore.QRect(0, 0, 551, 411))
        self.label_5.setStyleSheet("background-image: url(:/background/resource/設定背景.png);")
        self.label_5.setText("")
        self.label_5.setObjectName("label_5")
        # Stacking order: background first, every other widget above it
        for widget in (
                self.label_5, self.pushButton, self.pushButton_2,
                self.line_explore, self.line_initial, self.line_final,
                self.line_gamma, self.line_replay, self.line_batch,
                self.exploreSlider, self.label, self.label_2,
                self.initialSlider, self.label_3, self.finalSlider,
                self.label_4, self.gammaSlider, self.label_6,
                self.replaySlider, self.label_7, self.batchSlider):
            widget.raise_()
        # Fill in texts and default settings
        self.retranslateUi(self)
        # Link edit boxes and sliders
        self.connect()
        # Button slots
        self.pushButton.clicked.connect(self.saveSetting)
        self.pushButton_2.clicked.connect(self.cancel)
        QtCore.QMetaObject.connectSlotsByName(self)

    def _makeButton(self, x, style, object_name):
        """Create one of the two 75x23 buttons on the bottom row."""
        button = QtWidgets.QPushButton(self)
        button.setGeometry(QtCore.QRect(x, 320, 75, 23))
        button.setStyleSheet(style)
        button.setText("")
        button.setObjectName(object_name)
        return button

    def _makeLineEdit(self, y, object_name):
        """Create the edit box of the parameter row at height *y*."""
        edit = QtWidgets.QLineEdit(self)
        edit.setGeometry(QtCore.QRect(450, y, 61, 20))
        edit.setStyleSheet("color: rgb(0, 0, 0);")
        edit.setObjectName(object_name)
        return edit

    def _makeSlider(self, y, object_name, maximum, value, minimum=None):
        """Create the horizontal slider of the parameter row at height *y*."""
        slider = QtWidgets.QSlider(self)
        slider.setGeometry(QtCore.QRect(120, y, 300, 19))
        slider.setStyleSheet(self._SLIDER_STYLE)
        if minimum is not None:
            slider.setMinimum(minimum)
        slider.setMaximum(maximum)
        slider.setProperty("value", value)
        slider.setOrientation(QtCore.Qt.Horizontal)
        slider.setObjectName(object_name)
        return slider

    def _makeLabel(self, y, width, object_name):
        """Create the white caption label of the parameter row at height *y*."""
        label = QtWidgets.QLabel(self)
        label.setGeometry(QtCore.QRect(50, y, width, 19))
        label.setStyleSheet("color: rgb(255, 255, 255);")
        label.setObjectName(object_name)
        return label

    def retranslateUi(self, Dialog):
        """Set the window title, default field texts, captions and settings."""
        _translate = QtCore.QCoreApplication.translate
        Dialog.setWindowTitle(_translate("Dialog", "設置"))
        # Default texts of the edit boxes
        self.line_explore.setText(_translate("Dialog", "200000"))
        self.line_initial.setText(_translate("Dialog", "0"))
        self.line_final.setText(_translate("Dialog", "0"))
        self.line_gamma.setText(_translate("Dialog", "0.99"))
        self.line_replay.setText(_translate("Dialog", "50000"))
        self.line_batch.setText(_translate("Dialog", "32"))
        self.label.setText(_translate("Dialog", "Explore:"))
        self.label_2.setText(_translate("Dialog", "Initial:"))
        self.label_3.setText(_translate("Dialog", "Final:"))
        self.label_4.setText(_translate("Dialog", "Gamma:"))
        self.label_6.setText(_translate("Dialog", "Replay:"))
        self.label_7.setText(_translate("Dialog", "Batch:"))
        # Default settings
        self.setting = {"Explore": 200000, "Initial": 0, "Final": 0, "Gamma": 0.99, "Replay": 50000, "Batch": 32}

    # Link edit boxes and sliders
    def connect(self):
        """Keep every slider and its edit box in sync, in both directions."""
        self.exploreSlider.valueChanged.connect(self.changeLineExplore)
        self.line_explore.textChanged.connect(self.changeSliderExplore)
        self.initialSlider.valueChanged.connect(self.changeLineInitial)
        self.line_initial.textChanged.connect(self.changeSliderInitial)
        self.finalSlider.valueChanged.connect(self.changeLineFinal)
        self.line_final.textChanged.connect(self.changeSliderFinal)
        self.gammaSlider.valueChanged.connect(self.changeLineGamma)
        self.line_gamma.textChanged.connect(self.changeSliderGamma)
        self.replaySlider.valueChanged.connect(self.changeLineReplay)
        self.line_replay.textChanged.connect(self.changeSliderReplay)
        self.batchSlider.valueChanged.connect(self.changeLineBatch)
        self.line_batch.textChanged.connect(self.changeSliderBatch)

    @staticmethod
    def _setSliderFromText(slider, text, scale=1):
        """Move *slider* to the number typed in *text*; ignore invalid input.

        With scale 1 the text must be an integer. Otherwise it is read as a
        float and multiplied by *scale*; the product is rounded, because
        truncating e.g. 0.29 * 100 = 28.999... would give 28.
        """
        try:
            value = int(text) if scale == 1 else round(float(text) * scale)
            slider.setValue(value)
        except (ValueError, OverflowError):
            # Not a number (yet), e.g. an empty or half-typed field, or a
            # value too large for the slider.
            pass

    def changeLineExplore(self):
        self.line_explore.setText(str(self.exploreSlider.value()))

    def changeSliderExplore(self):
        self._setSliderFromText(self.exploreSlider, self.line_explore.text())

    def changeLineInitial(self):
        # The slider holds thousandths
        self.line_initial.setText(str(self.initialSlider.value() / 1000))

    def changeSliderInitial(self):
        self._setSliderFromText(self.initialSlider, self.line_initial.text(), 1000)

    def changeLineFinal(self):
        # The slider holds thousandths
        self.line_final.setText(str(self.finalSlider.value() / 1000))

    def changeSliderFinal(self):
        # The factor is applied to the parsed number, not to the text.
        self._setSliderFromText(self.finalSlider, self.line_final.text(), 1000)

    def changeLineGamma(self):
        # The slider holds hundredths
        self.line_gamma.setText(str(self.gammaSlider.value() / 100))

    def changeSliderGamma(self):
        self._setSliderFromText(self.gammaSlider, self.line_gamma.text(), 100)

    def changeLineReplay(self):
        self.line_replay.setText(str(self.replaySlider.value()))

    def changeSliderReplay(self):
        self._setSliderFromText(self.replaySlider, self.line_replay.text())

    def changeLineBatch(self):
        self.line_batch.setText(str(self.batchSlider.value()))

    def changeSliderBatch(self):
        self._setSliderFromText(self.batchSlider, self.line_batch.text())

    # Let other modules read the AI settings
    def getSetting(self):
        """Return the currently confirmed settings dict."""
        return self.setting

    # Save the settings
    def saveSetting(self):
        """Validate the fields, store them as numbers and hide the window.

        When a field is not a valid number, or Explore, Replay or Batch is
        not positive, nothing is stored and the window stays open so the
        input can be corrected.
        """
        try:
            candidate = {
                "Explore": int(self.line_explore.text()),
                "Initial": float(self.line_initial.text()),
                "Final": float(self.line_final.text()),
                "Gamma": float(self.line_gamma.text()),
                "Replay": int(self.line_replay.text()),
                "Batch": int(self.line_batch.text()),
            }
        except ValueError:
            return
        if min(candidate["Explore"], candidate["Replay"], candidate["Batch"]) <= 0:
            return
        self.setting = candidate
        self.hide()

    # Cancel
    def cancel(self):
        """Hide the window without storing anything; returns 0."""
        self.hide()
        return 0

    # Update the settings from a loaded project
    def updateSetting(self, setting):
        """Adopt the settings of a loaded project and show them in the fields."""
        self.setting = {key: setting[key] for key in ("Explore", "Initial", "Final", "Gamma", "Replay", "Batch")}
        self.line_explore.setText(str(setting["Explore"]))
        self.line_final.setText(str(setting["Final"]))
        self.line_initial.setText(str(setting["Initial"]))
        self.line_gamma.setText(str(setting["Gamma"]))
        self.line_replay.setText(str(setting["Replay"]))
        self.line_batch.setText(str(setting["Batch"]))
-
深度強化學習
該部分代碼參考https://blog.csdn.net/songrotek/article/details/50951537。 深度強化學習原理我這里不再贅述,大家可以查看該blog,有很詳細的講解。
主要由兩部分組成:DQL.py統一管理游戲和算法,DQLBrain.py則是深度強化學習算法核心。下面分別展示:
-
DQL.py
import cv2
from DQLBrain import Brain
import numpy as np
from collections import deque
import sqlite3
import pygame
import time
import gameSetting
import importlib

# Settings shared by all games
SCREEN_X = 288
SCREEN_Y = 512
FPS = 60


class AI:
    """Glue between one pygame game and the deep Q-learning ``Brain``.

    Creating an AI opens temp.db, builds the Brain, opens the game window and
    loads the game module named in the game settings. playGame() then runs the
    play/learn loop and records the score of every finished round.
    """

    def __init__(self, title, model_path, replay_memory, current_timestep,
                 explore, initial_epsilon, final_epsilon, gamma, replay_size,
                 batch_size):
        # Constants
        self.scores = deque()
        self.games_info = gameSetting.getSetting()
        # Set by closeGame(); makes closing idempotent and ends playGame()
        self._closed = False
        # Connect to the temporary database and make sure the table exists
        self.data_base = sqlite3.connect('temp.db', check_same_thread=False)
        self.c = self.data_base.cursor()
        self.c.execute('create table if not exists scores (time integer, score integer)')
        # Create the deep reinforcement learning object
        self.brain = Brain(self.games_info[title]["action"], model_path,
                           replay_memory, current_timestep, explore,
                           initial_epsilon, final_epsilon, gamma, replay_size,
                           batch_size)
        # Create the game window
        self.startGame(title, SCREEN_X, SCREEN_Y)
        # Load the module of the selected game
        game = importlib.import_module(self.games_info[title]['class'])
        self.game = game.Game(self.screen)

    def startGame(self, title, SCREEN_X, SCREEN_Y):
        """Initialise pygame and create the window, screen and clock."""
        pygame.init()
        screen_size = (SCREEN_X, SCREEN_Y)
        pygame.display.set_caption(title)
        # Create the screen
        self.screen = pygame.display.set_mode(screen_size)
        # Create the game clock
        self.clock = pygame.time.Clock()

    # Pre-process the frame to reduce its complexity
    def preProcess(self, observation):
        """Return the frame as an (80, 80, 1) black-and-white image."""
        # Resize the frame to 80*80 and convert it from three channels to
        # one grayscale channel
        observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
        # Turn every non-black pixel white
        threshold, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
        # The trailing axis of size 1 lets frames be stacked along axis 2
        return np.reshape(observation, (80, 80, 1))

    # Run the game
    def playGame(self):
        """Play and learn until the game window is destroyed or closeGame() runs."""
        # Start the game with an arbitrary first action. The one-hot vector
        # has as many entries as the game has actions.
        first_action = np.zeros(self.brain.actions)
        first_action[0] = 1
        observation0, reward0, terminal, score = self.game.frameStep(first_action)
        observation0 = self.preProcess(observation0)
        self.brain.setInitState(observation0[:, :, 0])
        # Main loop
        while not self._closed:
            action = self.brain.getAction()
            next_observation, reward, terminal, score = self.game.frameStep(action)
            # terminal == -1 signals that the game window was destroyed
            if terminal == -1:
                self.closeGame()
                return
            # Keep playing
            next_observation = self.preProcess(next_observation)
            self.brain.setPerception(next_observation, action, reward, terminal)
            # Record the score of every finished round
            if terminal:
                # sqlite3 cannot bind NumPy scalars; convert them to plain
                # Python numbers first.
                plain_score = score.item() if hasattr(score, 'item') else score
                self.c.execute("insert into scores values (?,?)",
                               (int(time.time()), plain_score))

    # Close the game
    def closeGame(self):
        """Shut down pygame, the Brain and the database; safe to call twice."""
        if self._closed:
            return
        self._closed = True
        pygame.quit()
        self.brain.close()
        time.sleep(0.5)  # give the brain time to finish its pending work
        # Keep the scores recorded since the last periodic commit
        self.data_base.commit()
        self.data_base.close()

    # Current state of the game
    def getState(self):
        """Return the Brain's latest state dict."""
        return self.brain.getState()

    # Current replay data, to be stored in the project file
    def getReplay(self):
        """Return the Brain's replay buffer."""
        return self.brain.replay_memory
-
-
-
DQLBrain.py
# NOTE(review): this listing uses tf, np, random and deque without showing
# their imports; presumably they are imported at the top of DQLBrain.py.
observe = 100


class Brain:
    """Deep Q-learning core: Q network, target network and replay buffer.

    Parameters:
        actions: number of actions the game accepts.
        model_path: checkpoint prefix of a saved model to restore; "" or None
            starts from freshly initialised weights.
        replay_memory: replay buffer to continue with; a new deque by default.
        current_timestep: timesteps already trained in earlier sessions.
        explore: number of timesteps over which epsilon is annealed.
        initial_epsilon / final_epsilon: start and end value of epsilon.
        gamma: discount factor of future rewards.
        replay_size: maximum length of the replay buffer.
        batch_size: size of a training minibatch.
    """

    def __init__(self, actions, model_path, replay_memory=None,
                 current_timestep=0, explore=200000., initial_epsilon=0.0,
                 final_epsilon=0.0, gamma=0.99, replay_size=50000,
                 batch_size=32):
        import os  # local import: only needed to create the checkpoint folder

        # Hyper-parameters:
        # discount factor
        self.gamma = gamma
        # number of timesteps to observe before training starts
        self.observe = observe
        # number of timesteps over which epsilon decreases
        self.explore = explore
        # epsilon at the start
        self.initial_epsilon = initial_epsilon
        # epsilon at the end
        self.final_epsilon = final_epsilon
        # size of the replay buffer
        self.replay_size = replay_size
        # size of a minibatch
        self.batch_size = batch_size
        self.update_time = 100
        self.whole_state = dict()
        # Replay buffer. A new deque is created per instance: a deque() default
        # argument would be shared by every Brain.
        self.replay_memory = deque() if replay_memory is None else replay_memory
        # Other state
        self.timestep = 0
        self.initial_timestep = current_timestep
        self.accual_timestep = self.initial_timestep + self.timestep
        # In load mode epsilon continues from the already trained timesteps
        self.epsilon = self._annealedEpsilon()
        self.actions = actions
        # Q network (the one being trained)
        (self.state_input, self.QValue,
         self.conv1_w, self.conv1_b, self.conv2_w, self.conv2_b,
         self.conv3_w, self.conv3_b, self.fc1_w, self.fc1_b,
         self.fc2_w, self.fc2_b) = self.createQNetwork()
        # Target Q network
        (self.state_inputT, self.QValueT,
         self.conv1_wT, self.conv1_bT, self.conv2_wT, self.conv2_bT,
         self.conv3_wT, self.conv3_bT, self.fc1_wT, self.fc1_bT,
         self.fc2_wT, self.fc2_bT) = self.createQNetwork()
        self.copyTargetQNetwork = [
            self.conv1_wT.assign(self.conv1_w), self.conv1_bT.assign(self.conv1_b),
            self.conv2_wT.assign(self.conv2_w), self.conv2_bT.assign(self.conv2_b),
            self.conv3_wT.assign(self.conv3_w), self.conv3_bT.assign(self.conv3_b),
            self.fc1_wT.assign(self.fc1_w), self.fc1_bT.assign(self.fc1_b),
            self.fc2_wT.assign(self.fc2_w), self.fc2_bT.assign(self.fc2_b)]
        # Loss function
        self.action_input = tf.placeholder("float", [None, self.actions])
        self.y_input = tf.placeholder("float", [None])
        Q_Action = tf.reduce_sum(tf.multiply(self.QValue, self.action_input), reduction_indices=1)
        self.cost = tf.reduce_mean(tf.square(self.y_input - Q_Action))
        self.optimizer = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
        # Saving and reloading the model
        self.saver = tf.train.Saver(max_to_keep=1)
        self.session = tf.InteractiveSession()
        self.session.run(tf.initialize_all_variables())
        # trainQNetwork() saves into this folder
        os.makedirs('./saved_networks', exist_ok=True)
        if model_path:
            # Load mode: continue from the saved weights. On failure training
            # starts from the fresh weights and the error is printed.
            try:
                self.saver.restore(self.session, model_path)
            except Exception:
                import traceback
                traceback.print_exc()

    def _annealedEpsilon(self):
        """Return epsilon for the current total timestep, never below final."""
        epsilon = (self.initial_epsilon
                   - (self.initial_epsilon - self.final_epsilon)
                   / self.explore * self.accual_timestep)
        return max(epsilon, self.final_epsilon)

    def createQNetwork(self):
        """Build one Q network; return its input, output and all variables."""
        # First convolution layer 8*8*4*32
        W_conv1 = self.weightVariable([8, 8, 4, 32])
        b_conv1 = self.biasVariable([32])
        # Second convolution layer 4*4*32*64
        W_conv2 = self.weightVariable([4, 4, 32, 64])
        b_conv2 = self.biasVariable([64])
        # Third convolution layer 3*3*64*64
        W_conv3 = self.weightVariable([3, 3, 64, 64])
        b_conv3 = self.biasVariable([64])
        # Fully connected layer 1600*512
        W_fc1 = self.weightVariable([1600, 512])
        b_fc1 = self.biasVariable([512])
        # Output layer 512*actions
        W_fc2 = self.weightVariable([512, self.actions])
        b_fc2 = self.biasVariable([self.actions])
        # Input layer
        stateInput = tf.placeholder("float", [None, 80, 80, 4])
        # Hidden layers
        h_conv1 = tf.nn.relu(self.conv2d(stateInput, W_conv1, 4) + b_conv1)
        # 20*20*32 to 10*10*32
        h_pool1 = self.maxPool_2x2(h_conv1)
        h_conv2 = tf.nn.relu(self.conv2d(h_pool1, W_conv2, 2) + b_conv2)
        # stride=1, 5*5*64 to 5*5*64
        h_conv3 = tf.nn.relu(self.conv2d(h_conv2, W_conv3, 1) + b_conv3)
        # 5*5*64 to 1*1600
        h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
        h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
        # Output layer
        QValue = tf.matmul(h_fc1, W_fc2) + b_fc2
        return (stateInput, QValue, W_conv1, b_conv1, W_conv2, b_conv2,
                W_conv3, b_conv3, W_fc1, b_fc1, W_fc2, b_fc2)

    def trainQNetwork(self):
        """Train on one minibatch; periodically save and sync the target net."""
        # Sample from the replay buffer
        minibatch = random.sample(self.replay_memory, self.batch_size)
        state_batch = [data[0] for data in minibatch]
        action_batch = [data[1] for data in minibatch]
        reward_batch = [data[2] for data in minibatch]
        nextState_batch = [data[3] for data in minibatch]
        # Targets of the loss function
        y_batch = []
        QValue_batch = self.QValueT.eval(feed_dict={self.state_inputT: nextState_batch})
        for i in range(0, self.batch_size):
            terminal = minibatch[i][4]
            if terminal:
                y_batch.append(reward_batch[i])
            else:
                y_batch.append(reward_batch[i] + self.gamma * np.max(QValue_batch[i]))
        self.optimizer.run(feed_dict={self.y_input: y_batch,
                                      self.action_input: action_batch,
                                      self.state_input: state_batch})
        # Save the network every 1000 timesteps
        if self.timestep % 1000 == 0:
            self.saver.save(self.session, './saved_networks/network' + '-dqn',
                            global_step=self.timestep + self.initial_timestep)
        # Copy the Q network into the target network
        if self.timestep % self.update_time == 0:
            self.session.run(self.copyTargetQNetwork)

    def setPerception(self, nextObservation, action, reward, terminal):
        """Store one transition, train if due, and publish the current state."""
        new_state = np.append(self.current_state[:, :, 1:], nextObservation, axis=2)
        self.replay_memory.append((self.current_state, action, reward, new_state, terminal))
        # Bound the size of the replay buffer
        if len(self.replay_memory) > self.replay_size:
            self.replay_memory.popleft()
        if self.timestep > self.observe:
            self.trainQNetwork()
        # Training information shown in the main window
        if self.timestep <= self.observe:
            state = "observe"
        elif self.timestep > self.observe and self.timestep <= self.observe + self.explore:
            state = "explore"
        else:
            state = "train"
        self.whole_state = {"TIMESTEP": self.timestep + self.initial_timestep,
                            "STATE": state,
                            "EPSILON": self.epsilon,
                            "ACTUAL": int(self.timestep + self.initial_timestep)}
        self.current_state = new_state
        self.timestep += 1
        # Keep the total in step with the session counter; getAction() anneals
        # epsilon from this value.
        self.accual_timestep = self.initial_timestep + self.timestep

    def getAction(self):
        """Return a one-hot action chosen by the epsilon-greedy policy."""
        QValue = self.QValue.eval(feed_dict={self.state_input: [self.current_state]})[0]
        action = np.zeros(self.actions)
        # Epsilon-greedy policy
        if random.random() <= self.epsilon:
            action_index = random.randrange(self.actions)
            action[action_index] = 1
        else:
            action_index = np.argmax(QValue)
            action[action_index] = 1
        # Anneal epsilon once the observation phase is over
        if self.epsilon > self.final_epsilon and self.accual_timestep > self.observe:
            self.epsilon = self._annealedEpsilon()
        return action

    def setInitState(self, observation):
        """Initialise the state by stacking the first frame four times."""
        self.current_state = np.stack((observation, observation, observation, observation), axis=2)

    def weightVariable(self, shape):
        initial = tf.truncated_normal(shape, stddev=0.01)
        return tf.Variable(initial)

    def biasVariable(self, shape):
        initial = tf.constant(0.01, shape=shape)
        return tf.Variable(initial)

    def conv2d(self, x, W, stride):
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")

    def maxPool_2x2(self, x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    def close(self):
        """Close the TensorFlow session."""
        self.session.close()

    def getState(self):
        """Return the latest state dict (empty until the first perception)."""
        return self.whole_state
-
-
服務器
主要采用highchart的API。在static文件夾中放好上述的四項文件后,在templates文件夾中寫好服務器界面的代碼index.html(為了方便大家學習,界面寫得相當簡陋hh):
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src='/static/jquery.js'></script>
<script src='/static/highstock.js'></script>
<script src='/static/exporting.js'></script>
</head>
<body>
<!-- Container the score chart is rendered into -->
<div id="container" style="min-width:310px;height:400px"></div>
<script>
$(function () {
    // Use the local time zone; with UTC the times would be off by eight
    // hours in UTC+8
    Highcharts.setOptions({
        global: {
            useUTC: false
        }
    });
    $.getJSON('/data', function (data) {
        // Create the chart
        $('#container').highcharts('StockChart', {
            chart: {
                events: {
                    load: function () {
                        var series = this.series[0];
                        // Poll the server every 3 s. /data always returns
                        // the complete score list, so the series is replaced;
                        // appending each row with addPoint would duplicate
                        // every point on every poll.
                        setInterval(function () {
                            $.getJSON('/data', function (res) {
                                series.setData(res);
                            });
                        }, 3000);
                    }
                }
            },
            rangeSelector: {
                selected: 1
            },
            title: {
                text: '每局分數'
            },
            series: [{
                name: '訓練表現',
                data: data,
                tooltip: {
                    valueDecimals: 2
                }
            }]
        });
    });
});
</script>
</body>
</html>
同時還需要編寫一個實時調用該模板的py文件(主界面中以webServers模塊導入,即webServers.py):
from flask import Flask,render_template,request
import sqlite3
import json
app = Flask(__name__)

# Temporary database that the AI writes the scores into
DATABASE = 'temp.db'


# Front-end template
@app.route('/')
def index():
    """Serve the chart page."""
    return render_template("index.html")


# Data source of the chart
@app.route('/data')
def data():
    """Return all scores as a JSON list of [time in ms, score] pairs.

    A connection is opened per request instead of sharing one module-level
    cursor between request threads. Before the first game has created the
    scores table an empty list is returned.
    """
    connection = sqlite3.connect(DATABASE)
    try:
        rows = connection.execute('select * from scores').fetchall()
    except sqlite3.OperationalError:
        # e.g. "no such table: scores"
        rows = []
    finally:
        connection.close()
    # The table stores seconds; the chart expects milliseconds
    payload = json.dumps([[row[0] * 1000, row[1]] for row in rows])
    return app.response_class(payload, mimetype='application/json')


# Start the server on port 9090; 0.0.0.0 listens on all network interfaces
def start():
    """Run the Flask server (blocking)."""
    app.run(host='0.0.0.0', port=9090)
結語
不過貌似PyQt5和tensorflow會有沖突,因此實際運行的時候會偶爾出現崩潰。另外服務器無法由外網的機器連接。如果大家知道怎么解決這些問題請在下方留言告訴我,謝謝!最后再來一次:github地址為https://github.com/qq303067814/DQLearning-Toolbox, 如果講解中有部分還想繼續了解的話可以直接查看源代碼,或者在留言中提出。訓練簡單小游戲的強化學習工具箱
注:本文著作權歸作者,由demo大師代發,拒絕轉載,轉載需要作者授權