问题描述
旅行商问题即TSP(traveling salesman problem),也就是求解最短汉密尔顿回路问题.
给定一个图G,要求找一条回路,使得该回路过每个顶点一次且仅一次,并且要让这条路最短.
关于遗传算法的几个概念
遗传算法模拟了达尔文自然选择,繁殖变异的过程.
- 种群:个体的集合.一开始需要设定种群的大小.在遗传算法中,种群的大小可以是固定长度的,也可以是变长的.总之,它是一个集合.
- cross:交叉,找两个个体(或者更多个体)让他们进行交配繁殖出新的下一代.那些比较优秀的个体能够以更大的概率获得更多的交配机会.遗传算法总是想当然的认为:优秀父母的基因组合之后依然优秀.
- mutate:变异,由于大自然中宇宙射线,各种生化反应会导致个体基因发生变异.在遗传算法中一定要注意,变异不会导致个体发生改变,而是会产生新的个体.在遗传算法中,一切个体一旦生成它的基因就不再发生改变.否则,好不容易求出来的最佳解可能一变异就消失了,导致算法收敛缓慢.一言以蔽之,变异就像是无性生殖,是个体自己复制了一个自己然后在复制的过程中发生了很多错乱.在进行交叉繁殖时,是优先选择优秀的个体,在变异中,每个个体人人平等,大家都有平等的概率来发生变异.
- fitness:每个个体基因与生俱来,不可改变.它的基因决定了它对环境的适应程度.对环境适应性强的个体有更多机会繁殖后代.也就是说在select的过程中,会以更大的概率选择优秀的个体作为父代.那么如何实现根据个体的适应程度来概率性地选择呢?这个就相当于一个几何概型,就像"转盘抽奖"一样.
- select:自然选择,根据fitness来选定优秀的个体.
遗传算法解决TSP问题的思路
(1)对问题的每一种可行解进行编码
这种编码其实就是把可行解用一个细长的东西来表达(这个东西就像染色体一样,上面带着许多基因).在TSP问题中,一个可行解的编码当然就是一个旅行序列,也就是顶点序列.在设计编码的时候,要考虑到如何控制交叉和变异,毕竟编码是要发生交叉和变异的,而交叉和变异也是最关键的部分.
(2)适应度fitness
每一个个体基因与生俱来无法改变,它的适应度值也是由基因计算得来无法改变.在TSP问题中,很显然路径花费越小,个体的环境适应性越强.也就是说,需要找到一个减函数f(x),使得适应性fitness=f(cost).减函数太多了,随便举一个就可以了,比如exp(-x),1/x等等.下面程序中使用了1/x.这样会导致个体之间的fitness相差较小,因为1/x随着x增大,减少的越来越慢,所以最好找一个形状合适的减函数,比如y=-x+b.
(3)选择概率p
把整个种群的fitness求个总和s,每个个体的选择概率就是person.fitness/s.然后就可以像转盘抽奖一样进行选择,以person.fitness/s的概率选择该个体去繁殖后代.
(4)变异mutate
任何一个可行解都是一个1~N的全排列,变异不就是随意shuffle几次吗.任意交换若干个数的位置即可.
(5)交叉cross
交叉大有文章,对于遗传算法适用的问题,交叉设计的好坏事关重大.对于不知道遗传算法是否管用的问题,交叉就是瞎整.比如,对于子代son,它的son.gene[3]取值以一定的概率取自父亲,以一定的概率取自母亲,如果这个基因与前面的某个基因重复,那么就从未使用的基因里面随机选取一个基因作为gene[3].
在我看来,遗传算法就是瞎几把整.
代码
python的思想就是快捷优雅高级,运行效率不是大事.
下面代码有很多大优化空间,但是优化之后代码就变多了.
比如在迭代过程中,每次至多产生一个新个体,这个个体需要插入到种群序列中去,并将另一个个体移除掉,这不需要全局排序,只需要从后往前来一次数组的插入操作即可.不过,这些都不是事.
再比如,getDis()函数每次都求一次距离,这距离当然可以先打表保存起来一个距离矩阵.
再比如,轮盘赌选择父代时,可以先累加一下存储起来,然后进行二分查找,可以从O(n)降到O(lgn)
import itertools
import math
import random
import matplotlib.pyplot as plt
from numpy.random import rand
N = 8 # 基因的长度,也就是城市的个数
g = rand(N, 2) * 10 # 随机产生N个城市的坐标
# 获取两个城市之间的距离
def getDis(i, j):
return math.hypot(g[i][0] - g[j][0], g[i][1] - g[j][1])
# 一个个体
class Person:
def __init__(self, gene=None):
if not gene:
gene = list(range(N))
random.shuffle(gene)
self.gene = gene
self.cost = sum([getDis(gene[i], gene[(i + 1) % N]) for i in range(N)])
self.fitness = 1 / self.cost
self.p = 0
def __str__(self):
return "{} fitness={} cost={}".format(str(self.gene), self.fitness, self.cost)
def __lt__(self, other):
return self.fitness > other.fitness
# 根据适应程度计算存活概率
def getP():
s = sum([person.fitness for person in people])
for person in people:
person.p = person.fitness / s
# 概率性选择一个适应力最强的个体
def select():
s = 0
p = random.random()
for person in people:
s += person.p
if s >= p: return person
# 交配繁殖
def cross(fa, mo):
gene = [0] * N
not_used = list(range(N))
for i in range(N):
if fa.gene[i] in not_used and mo.gene[i] in not_used:
gene[i] = fa.gene[i] if random.random() < 0.5 else mo.gene[i]
elif fa.gene[i] in not_used:
gene[i] = fa.gene[i]
elif mo.gene[i] in not_used:
gene[i] = mo.gene[i]
else:
gene[i] = not_used[random.randint(0, len(not_used) - 1)]
not_used.remove(gene[i])
return Person(gene)
# 变异,变异之后应该产生新的个体而不应该替换掉原来的个体
def mutate(person):
gene = [person.gene[i] for i in range(N)]
for i in range(random.randint(0, mutation_scale)):
x = random.randint(0, N - 1)
y = random.randint(0, N - 1)
gene[x], gene[y] = gene[y], gene[x]
return Person(gene)
#用全排列来求真正的答案,来检测结果正确性
def real_ans():
best = Person(list(range(N)))
for i in itertools.permutations(range(N)):
if best.cost > Person(i).cost:
best = Person(i)
return best
def draw(person, pos, title):
x, y = [g[i][0] for i in person.gene], [g[i][1] for i in person.gene]
mine = plt.subplot(pos, title=title + str(person.cost))
mine.plot(x, y, 'o-', linewidth=2, color='r')
people_size = 10 # 种群大小
people = [Person() for i in range(people_size)] # 种群
cross_probability = 0.5 # 交配的概率,决定了进化的速度
mutation_probability = 0.3 # 子代发生变异的概率
mutation_scale = N // 2 # 每次变异最多变异的基因数
generation_cnt = 1000 # 代数
def gene():
global people
for generation in range(generation_cnt):
getP()
if random.random() < cross_probability:
people.append(cross(select(), select()))
if random.random() < mutation_probability:
people.append(mutate(people[random.randint(0, people_size - 1)]))
people.sort()
people = people[0:people_size]
print(",".join([str(person.cost) for person in people]))
return people[0]
ans = gene()
true_ans = real_ans()
draw(ans, 121, "mine ")
draw(true_ans, 122, "real ans ")
plt.show()