【人工智能導論】A*算法求解八數碼問題


A*算法是一種啟發式搜索算法,它的關鍵在於,每次從open表中選取結點時,要按特定的策略選取。該策略如下所述:

  引入估值函數, f(n)是結點n的函數,f(n)越小,就意味着從初始狀態節點S通過結點n的路徑長度的估值最短。簡而言之,f(n)越小,則通過結點n的路徑是最佳路徑的可能性越大。

 

  因此,從open表中選取結點之前,我們先以f(n)為指標對open表中的所有節點進行排序,然后選取f(n)最小者即可。

    具體地,在本題中

 

 

 

 

 

以下是完整代碼

  1 # from queue import Queue
  2 from 八數碼問題BFS import move0  # move0函數參考上一篇博文:https://www.cnblogs.com/vivlalib/p/12557518.html
  3 import numpy as np
  4 import time
  5 
  6 class Node:
  7     def __init__(self, state, parent, operator):
  8         self.state = state
  9         self.parent = parent
 10         self.operator = operator
 11     def operate(self, dir):  # 操作算符
 12         new_state = move0(self.state, dir)
 13         if new_state is False:
 14             return None
 15         else:
 16             return Node(new_state, self, dir)  # 以self為父
 17     def islegal(self):  # 檢查結點的合法性, 根據實際問題改變
 18         return self
 19     def traverse(self):
 20         cur = self
 21         res = []
 22         while cur is not None:
 23             # print(cur.state)
 24             res.append(cur)
 25             cur = cur.parent
 26         # res.pop()
 27         return res
 28     def depth(self):  # 當前結點到根節點的距離,根據實際問題改變
 29         return len(self.traverse()) - 1
 30 
 31 def findn(state, n):  # 在state中尋找n的位置
 32     for i in range(len(state)):
 33         for j in range(len(state)):
 34             if state[i][j] == n:
 35                 ind0 = [i, j]
 36     return ind0
 37 def manhattan(pos1, pos2):  # pos1, pos2 are like [x, y]
 38     return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
 39 
 40 
 41 def distance0(state1, state2):  # 自己亂寫的距離函數,很慢  15s
 42     state1 = np.array(state1)
 43     state2 = np.array(state2)
 44     ret = np.sum(np.abs(state1-state2))  # ∑|state1[i] - state2[i]|  對應維度差的絕對值之和,即曼哈頓距離
 45     return ret
 46 
 47 def distance(state1, state2):  # 這里使用曼哈頓距離,  0.21s
 48 
 49     posn1, posn2 = [], []
 50     for n in range(9):  # 分別在state1與state2中查找0到9的位置
 51         posn1.append(findn(state1, n))
 52         posn2.append(findn(state2, n))
 53     dsum = 0
 54     for i in range(9):
 55         dsum += manhattan(posn1[i], posn2[i])
 56     return dsum
 57 
 58 
 59 
 60 
 61 def f(node:Node, goal_state):  # 估價函數
 62     """
 63     公式表示為: f(n)=g(n)+h(n),
 64     其中 f(n) 是從初始點經由節點n到目標點的估價函數,
 65     g(n) 是在狀態空間中從初始節點到n節點的實際代價,
 66     h(n) 是從n到目標節點最佳路徑的估計代價。
 67     :param node:
 68     :return:
 69     """
 70     cur_state = np.array(node.state)
 71     goal_state = np.array(goal_state)
 72     gn = node.depth()
 73     hn = distance(cur_state, goal_state)
 74     return gn + hn
 75 
 76 
 77 def showinfo(node: Node):  # 用於搜索成功后的輸出
 78     nlist = node.traverse()
 79     nlist.reverse()
 80     for n in nlist:
 81         if n.operator is not None:
 82             print(n.operator)
 83         print(n.state)
 84 
 85 
 86 def Astar(init_state, goal_state):
 87     dirs = ['up', 'down', 'left', 'right']
 88     open = []  # open表
 89     closed = []
 90     root = Node(init_state, None, None)  # 起始狀態,根節點
 91     open.append(root)
 92     while open:  # 若open表為空,退出循環
 93         open.sort(key=lambda x: f(x, goal_state))
 94         node = open.pop(0)  # 選出open表中f(n)最小者
 95         closed.append(node)
 96         # 擴展
 97         for dir in dirs:
 98             node_tmp = node.operate(dir)  # 以dir為方向拓展結點
 99             # 新節點合法
100             if node_tmp is not None and node_tmp not in closed:
101                 open.append(node_tmp)
102                 if node_tmp.state == goal_state:  # node_tmp 是目標結點?
103                     print('搜索成功。')
104                     showinfo(node_tmp)
105                     return True
106     return False
107 
108 goalss = [[1, 2, 3],
109         [8, 0, 4],
110         [7, 6, 5]]
111 goal = [[2,0,3],
112         [1,8,4],
113         [7,6,5]]
114 init = [[2,8,3],
115         [1,0,4],
116         [7,6,5]]
117 
118 goal = [[1, 2, 3],
119         [8, 0, 4],
120         [7, 6, 5]]
121 
122 init = [[2,8,0],
123         [1,6,3],
124         [7,5,4]]
125 if __name__ == '__main__':
126 
127     print(Astar(init,goal))
128     print(time.process_time())

運行結果如下:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM