粒子群算法(1) - Python實現


  • 抽象來源:模仿自然界中的鳥群覓食行為。
  • 核心思想:在自然界鳥群覓食過程中,我們可以想象食物自身散發某種着香味(實際上可能不是,此處僅以香味為例代表鳥群可能獲得的某種食物信息),該香味距離食物越近則越濃(以狀態函數值進行描述)。並且,我們假設群體中每只鳥的飛行行為均且僅受到三方面因素的影響和貢獻:1)每只鳥自身的飛行慣性 --- 自身慣性貢獻;2)每只鳥的歷史最優狀態 --- 自身認知貢獻;3)整個鳥群的歷史最優狀態 --- 群體經驗貢獻。注意,此處以與食物的距離(即香味濃度,也就是狀態函數值)來判斷飛行狀態的優劣。根據此三類貢獻,狀態空間內每只鳥將逐漸調整自身的飛行速度(包括大小、方向),並向食物位置(即局部香味最濃的位置)匯聚。因此,在相關參數設置合理的前提下,粒子群算法的最終解應該對應於給定狀態空間內的最值。
  • 迭代公式
    粒子群速度更新公式
    \begin{equation}
    V_i(t+1) = \omega V_i(t) + c_1r_1(pbest_i - X_i(t)) + c_2r_2(gbest-X_i(t))
    \end{equation}
    該式右端三項分別代表:自身慣性貢獻、自身認知貢獻以及群體經驗貢獻。其中,$\omega$代表慣性因子,$c_1$、$c_2$代表學習因子,$r_1$、$r_2$為$[0, 1]$之間的均勻隨機數,$pbest_i$為第$i$個粒子已知的歷史最優狀態或位置,$gbest$整個粒子群已知的歷史最優狀態或位置。$V_i(t)$與$X_i(t)$分別代表$t$時刻粒子$i$的速度與位置。
    由於實際問題可能處於多維空間內,因此有:
    \begin{equation}
    \begin{cases}
    V_i = (v_{i1}, v_{i2}, ..., v_{iD})\\
    X_i = (x_{i1}, x_{i2}, ..., x_{iD})
    \end{cases}
    \end{equation}
    其中,$D$為空間維數。
    粒子群位置更新公式
    \begin{equation}
    X_i(t+1) = X_i(t) + V_i(t+1)
    \end{equation} 
  • Python代碼實現
     1 import numpy as np
     2 import matplotlib.pyplot as plt
     3 import random
     4 
     5 
     6 # 定義“粒子”類
     7 class parti(object):
     8     def __init__(self, v, x):
     9         self.v = v                    # 粒子當前速度
    10         self.x = x                    # 粒子當前位置
    11         self.pbest = x                # 粒子歷史最優位置
    12         
    13 class PSO(object):
    14     def __init__(self, interval, tab='min', partisNum=10, iterMax=1000, w=1, c1=2, c2=2):
    15         self.interval = interval                                            # 給定狀態空間 - 即待求解空間
    16         self.tab = tab.strip()                                              # 求解最大值還是最小值的標簽: 'min' - 最小值;'max' - 最大值
    17         self.iterMax = iterMax                                              # 迭代求解次數
    18         self.w = w                                                          # 慣性因子
    19         self.c1, self.c2 = c1, c2                                           # 學習因子
    20         self.v_max = (interval[1] - interval[0]) * 0.1                      # 設置最大遷移速度
    21         #####################################################################
    22         self.partis_list, self.gbest = self.initPartis(partisNum)                 # 完成粒子群的初始化,並提取群體歷史最優位置
    23         self.x_seeds = np.array(list(parti_.x for parti_ in self.partis_list))    # 提取粒子群的種子狀態 ###
    24         self.solve()                                                              # 完成主體的求解過程
    25         self.display()                                                            # 數據可視化展示
    26         
    27     def initPartis(self, partisNum):
    28         partis_list = list()
    29         for i in range(partisNum):
    30             v_seed = random.uniform(-self.v_max, self.v_max)
    31             x_seed = random.uniform(*self.interval)
    32             partis_list.append(parti(v_seed, x_seed))
    33         temp = 'find_' + self.tab
    34         if hasattr(self, temp):                                             # 采用反射方法提取對應的函數
    35             gbest = getattr(self, temp)(partis_list)
    36         else:
    37             exit('>>>tab標簽傳參有誤:"min"|"max"<<<')
    38         return partis_list, gbest
    39         
    40     def solve(self):
    41         for i in range(self.iterMax):
    42             for parti_c in self.partis_list:
    43                 f1 = self.func(parti_c.x)
    44                 # 更新粒子速度,並限制在最大遷移速度之內
    45                 parti_c.v = self.w * parti_c.v + self.c1 * random.random() * (parti_c.pbest - parti_c.x) + self.c2 * random.random() * (self.gbest - parti_c.x)
    46                 if parti_c.v > self.v_max: parti_c.v = self.v_max
    47                 elif parti_c.v < -self.v_max: parti_c.v = -self.v_max
    48                 # 更新粒子位置,並限制在待解空間之內
    49                 if self.interval[0] <= parti_c.x + parti_c.v <=self.interval[1]:
    50                     parti_c.x = parti_c.x + parti_c.v 
    51                 else:
    52                     parti_c.x = parti_c.x - parti_c.v
    53                 f2 = self.func(parti_c.x)
    54                 getattr(self, 'deal_'+self.tab)(f1, f2, parti_c)             # 更新粒子歷史最優位置與群體歷史最優位置      
    55         
    56     def func(self, x):                                                       # 狀態產生函數 - 即待求解函數
    57         value = np.sin(x**2) * (x**2 - 5*x)
    58         return value
    59         
    60     def find_min(self, partis_list):                                         # 按狀態函數最小值找到粒子群初始化的歷史最優位置
    61         parti = min(partis_list, key=lambda parti: self.func(parti.pbest))
    62         return parti.pbest
    63         
    64     def find_max(self, partis_list):
    65         parti = max(partis_list, key=lambda parti: self.func(parti.pbest))   # 按狀態函數最大值找到粒子群初始化的歷史最優位置
    66         return parti.pbest
    67         
    68     def deal_min(self, f1, f2, parti_):
    69         if f2 < f1:                          # 更新粒子歷史最優位置
    70             parti_.pbest = parti_.x
    71         if f2 < self.func(self.gbest):
    72             self.gbest = parti_.x            # 更新群體歷史最優位置
    73             
    74     def deal_max(self, f1, f2, parti_):
    75         if f2 > f1:                          # 更新粒子歷史最優位置
    76             parti_.pbest = parti_.x
    77         if f2 > self.func(self.gbest):
    78             self.gbest = parti_.x            # 更新群體歷史最優位置
    79             
    80     def display(self):
    81         print('solution: {}'.format(self.gbest))
    82         plt.figure(figsize=(8, 4))
    83         x = np.linspace(self.interval[0], self.interval[1], 300)
    84         y = self.func(x)
    85         plt.plot(x, y, 'g-', label='function')
    86         plt.plot(self.x_seeds, self.func(self.x_seeds), 'b.', label='seeds')
    87         plt.plot(self.gbest, self.func(self.gbest), 'r*', label='solution')
    88         plt.xlabel('x')
    89         plt.ylabel('f(x)')
    90         plt.title('solution = {}'.format(self.gbest))
    91         plt.legend()
    92         plt.savefig('PSO.png', dpi=500)
    93         plt.show()
    94         plt.close()
    95 
    96         
    97 if __name__ == '__main__':
    98     PSO([-9, 5], 'max')
    View Code

     筆者所用示例函數為:
    \begin{equation}
    f(x) = (x^2 - 5x)sin(x^2)
    \end{equation}

  • 結果展示
  • 參考:https://wenku.baidu.com/view/0fdb3dff87c24028905fc321.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM