問題描述
旅行商問題即TSP(traveling salesman problem),也就是求解最短漢密爾頓回路問題.
給定一個圖G,要求找一條回路,使得該回路過每個頂點一次且僅一次,並且要讓這條路最短.
關於遺傳算法的幾個概念
遺傳算法模擬了達爾文自然選擇,繁殖變異的過程.
- 種群:個體的集合.一開始需要設定種群的大小.在遺傳算法中,種群的大小可以是固定長度的,也可以是變長的.總之,它是一個集合.
- cross:交叉,找兩個個體(或者更多個體)讓他們進行交配繁殖出新的下一代.那些比較優秀的個體能夠以更大的概率獲得更多的交配機會.遺傳算法總是想當然的認為:優秀父母的基因組合之后依然優秀.
- mutate:變異,由於大自然中宇宙射線,各種生化反應會導致個體基因發生變異.在遺傳算法中一定要注意,變異不會導致個體發生改變,而是會產生新的個體.在遺傳算法中,一切個體一旦生成它的基因就不再發生改變.否則,好不容易求出來的最佳解可能一變異就消失了,導致算法收斂緩慢.一言以蔽之,變異就像是無性生殖,是個體自己復制了一個自己然后在復制的過程中發生了很多錯亂.在進行交叉繁殖時,是優先選擇優秀的個體,在變異中,每個個體人人平等,大家都有平等的概率來發生變異.
- fitness:每個個體基因與生俱來,不可改變.它的基因決定了它對環境的適應程度.對環境適應性強的個體有更多機會繁殖后代.也就是說在select的過程中,會以更大的概率選擇優秀的個體作為父代.那么如何實現根據個體的適應程度來概率性地選擇呢?這個就相當於一個幾何概型,就像"轉盤抽獎"一樣.
- select:自然選擇,根據fitness來選定優秀的個體.
遺傳算法解決TSP問題的思路
(1)對問題的每一種可行解進行編碼
這種編碼其實就是把可行解用一個細長的東西來表達(這個東西就像染色體一樣,上面帶着許多基因).在TSP問題中,一個可行解的編碼當然就是一個旅行序列,也就是頂點序列.在設計編碼的時候,要考慮到如何控制交叉和變異,畢竟編碼是要發生交叉和變異的,而交叉和變異也是最關鍵的部分.
(2)適應度fitness
每一個個體基因與生俱來無法改變,它的適應度值也是由基因計算得來無法改變.在TSP問題中,很顯然路徑花費越小,個體的環境適應性越強.也就是說,需要找到一個減函數f(x),使得適應性fitness=f(cost).減函數太多了,隨便舉一個就可以了,比如exp(-x),1/x等等.下面程序中使用了1/x.這樣會導致個體之間的fitness相差較小,因為1/x隨着x增大,減少的越來越慢,所以最好找一個形狀合適的減函數,比如y=-x+b.
(3)選擇概率p
把整個種群的fitness求個總和s,每個個體的選擇概率就是person.fitness/s.然后就可以像轉盤抽獎一樣進行選擇,以person.fitness/s的概率選擇該個體去繁殖后代.
(4)變異mutate
任何一個可行解都是一個1~N的全排列,變異不就是隨意shuffle幾次嗎.任意交換若干個數的位置即可.
(5)交叉cross
交叉大有文章,對於遺傳算法適用的問題,交叉設計的好壞事關重大.對於不知道遺傳算法是否管用的問題,交叉就是瞎整.比如,對於子代son,它的son.gene[3]取值以一定的概率取自父親,以一定的概率取自母親,如果這個基因與前面的某個基因重復,那么就從未使用的基因里面隨機選取一個基因作為gene[3].
在我看來,遺傳算法就是瞎幾把整.
代碼
python的思想就是快捷優雅高級,運行效率不是大事.
下面代碼有很多大優化空間,但是優化之后代碼就變多了.
比如在迭代過程中,每次至多產生一個新個體,這個個體需要插入到種群序列中去,並將另一個個體移除掉,這不需要全局排序,只需要從后往前來一次數組的插入操作即可.不過,這些都不是事.
再比如,getDis()函數每次都求一次距離,這距離當然可以先打表保存起來一個距離矩陣.
再比如,輪盤賭選擇父代時,可以先累加一下存儲起來,然后進行二分查找,可以從O(n)降到O(lgn)
import itertools
import math
import random
import matplotlib.pyplot as plt
from numpy.random import rand
N = 8 # 基因的長度,也就是城市的個數
g = rand(N, 2) * 10 # 隨機產生N個城市的坐標
# 獲取兩個城市之間的距離
def getDis(i, j):
return math.hypot(g[i][0] - g[j][0], g[i][1] - g[j][1])
# 一個個體
class Person:
def __init__(self, gene=None):
if not gene:
gene = list(range(N))
random.shuffle(gene)
self.gene = gene
self.cost = sum([getDis(gene[i], gene[(i + 1) % N]) for i in range(N)])
self.fitness = 1 / self.cost
self.p = 0
def __str__(self):
return "{} fitness={} cost={}".format(str(self.gene), self.fitness, self.cost)
def __lt__(self, other):
return self.fitness > other.fitness
# 根據適應程度計算存活概率
def getP():
s = sum([person.fitness for person in people])
for person in people:
person.p = person.fitness / s
# 概率性選擇一個適應力最強的個體
def select():
s = 0
p = random.random()
for person in people:
s += person.p
if s >= p: return person
# 交配繁殖
def cross(fa, mo):
gene = [0] * N
not_used = list(range(N))
for i in range(N):
if fa.gene[i] in not_used and mo.gene[i] in not_used:
gene[i] = fa.gene[i] if random.random() < 0.5 else mo.gene[i]
elif fa.gene[i] in not_used:
gene[i] = fa.gene[i]
elif mo.gene[i] in not_used:
gene[i] = mo.gene[i]
else:
gene[i] = not_used[random.randint(0, len(not_used) - 1)]
not_used.remove(gene[i])
return Person(gene)
# 變異,變異之后應該產生新的個體而不應該替換掉原來的個體
def mutate(person):
gene = [person.gene[i] for i in range(N)]
for i in range(random.randint(0, mutation_scale)):
x = random.randint(0, N - 1)
y = random.randint(0, N - 1)
gene[x], gene[y] = gene[y], gene[x]
return Person(gene)
#用全排列來求真正的答案,來檢測結果正確性
def real_ans():
best = Person(list(range(N)))
for i in itertools.permutations(range(N)):
if best.cost > Person(i).cost:
best = Person(i)
return best
def draw(person, pos, title):
x, y = [g[i][0] for i in person.gene], [g[i][1] for i in person.gene]
mine = plt.subplot(pos, title=title + str(person.cost))
mine.plot(x, y, 'o-', linewidth=2, color='r')
people_size = 10 # 種群大小
people = [Person() for i in range(people_size)] # 種群
cross_probability = 0.5 # 交配的概率,決定了進化的速度
mutation_probability = 0.3 # 子代發生變異的概率
mutation_scale = N // 2 # 每次變異最多變異的基因數
generation_cnt = 1000 # 代數
def gene():
global people
for generation in range(generation_cnt):
getP()
if random.random() < cross_probability:
people.append(cross(select(), select()))
if random.random() < mutation_probability:
people.append(mutate(people[random.randint(0, people_size - 1)]))
people.sort()
people = people[0:people_size]
print(",".join([str(person.cost) for person in people]))
return people[0]
ans = gene()
true_ans = real_ans()
draw(ans, 121, "mine ")
draw(true_ans, 122, "real ans ")
plt.show()