1、算法描述
1、适应度函数
使用目标字符串的长度作为DNA的长度,DNA中的每个元素用对应字符的ASCII码(数字)表示,取值范围为[32,126](即可打印的ASCII字符)。
种群中某个个体与目标字符串(target)在对应位置上相同的字符越多,其 fitness 就越高。
2、进化
select、mutate、crossover函数的功能与上一篇博客(使用遗传算法寻找最高点)中的实现相似,这里不再赘述。
2、代码
# @Time : 2021/2/14 12:03
# @Description : 句子配对
import numpy as np
TARGET_PHRASE = 'You get it!'  # target DNA
POP_SIZE = 300  # population size
CROSS_RATE = 0.4  # mating probability (DNA crossover)
MUTATION_RATE = 0.01  # mutation probability
N_GENERATIONS = 1000
DNA_SIZE = len(TARGET_PHRASE)
# Convert the phrase to numbers: every character maps to its ASCII code.
# [ 89 111 117  32 103 101 116  32 105 116  33]
# np.fromstring on text is deprecated (and rejected by recent NumPy releases),
# so encode to bytes and read them with np.frombuffer. The .copy() keeps the
# array writable, as the original fromstring result was.
TARGET_ASCII = np.frombuffer(TARGET_PHRASE.encode('ascii'), dtype=np.uint8).copy()
ASCII_BOUND = [32, 126]  # inclusive range of ASCII codes a gene may take
class GA(object):
    """Genetic algorithm that evolves ASCII-coded strings toward a target.

    Each individual is a row of ``DNA_size`` integer genes, every gene being
    an ASCII code. Fitness is the number of positions at which an individual
    equals the target.
    """

    def __init__(self, DNA_size, DNA_bound, cross_rate, mutation_rate, pop_size,
                 target=None):
        """Create a random initial population.

        Args:
            DNA_size: number of genes (characters) per individual.
            DNA_bound: two-item sequence ``[low, high]`` with the inclusive
                range of gene values. It is NOT modified.
            cross_rate: probability that an individual undergoes crossover.
            mutation_rate: per-gene probability of mutation.
            pop_size: number of individuals in the population.
            target: optional array-like of target ASCII codes. When omitted,
                the module-level ``TARGET_ASCII`` is used by ``get_fitness``.
        """
        self.DNA_size = DNA_size
        # np.random.randint excludes the upper bound, so store high + 1.
        # Build a new list rather than incrementing DNA_bound in place, which
        # would silently widen the caller's list on every construction.
        self.DNA_bound = [DNA_bound[0], DNA_bound[1] + 1]
        self.cross_rate = cross_rate
        self.mutate_rate = mutation_rate
        self.pop_size = pop_size
        self.target = None if target is None else np.asarray(target)
        # int8 so that every gene occupies exactly one byte (one ASCII char).
        self.pop = np.random.randint(
            *self.DNA_bound, size=(pop_size, DNA_size)
        ).astype(np.int8)

    def translateDNA(self, DNA):
        """Return the readable ASCII string encoded by the gene array ``DNA``."""
        # Cast to uint8 so each gene is serialised as exactly one byte,
        # whatever integer dtype the caller passes. ndarray.tostring() was
        # removed from NumPy; tobytes() is the replacement.
        return np.asarray(DNA).astype(np.uint8).tobytes().decode('ascii')

    def get_fitness(self):
        """Return, per individual, how many genes match the target."""
        target = TARGET_ASCII if self.target is None else self.target
        # Compare position by position; each match counts as +1.
        match_count = (self.pop == target).sum(axis=1)
        return match_count

    def select(self):
        """Return a new population sampled with probability ~ fitness."""
        # Add a small amount so the probabilities are valid even when every
        # individual has zero matches.
        fitness = self.get_fitness() + 1e-4
        idx = np.random.choice(
            np.arange(self.pop_size),
            size=self.pop_size,
            replace=True,
            p=fitness / fitness.sum(),
        )
        return self.pop[idx]

    def crossover(self, parent, pop):
        """Mix genes from a random member of ``pop`` into ``parent``.

        With probability ``cross_rate`` a random subset of positions in
        ``parent`` is overwritten in place. ``parent`` is returned either way.
        """
        if np.random.rand() < self.cross_rate:
            # Pick another individual from pop.
            i_ = np.random.randint(0, self.pop_size, size=1)
            # Boolean mask of crossover points (np.bool no longer exists).
            cross_points = np.random.randint(0, 2, self.DNA_size).astype(bool)
            # Mating: the child takes the mate's genes at the masked points.
            parent[cross_points] = pop[i_, cross_points]
        return parent

    def mutate(self, child):
        """Replace each gene of ``child`` with a random one at ``mutate_rate``.

        ``child`` is modified in place and returned.
        """
        for point in range(self.DNA_size):
            if np.random.rand() < self.mutate_rate:
                # Choose a random ASCII code within the allowed range.
                child[point] = np.random.randint(*self.DNA_bound)
        return child

    def evolve(self):
        """Advance the population by one generation."""
        pop = self.select()
        # Mates are drawn from an untouched copy so that children written
        # back into pop during the loop do not act as mates.
        pop_copy = pop.copy()
        for parent in pop:  # for every parent
            child = self.crossover(parent, pop_copy)
            child = self.mutate(child)
            parent[:] = child
        self.pop = pop
if __name__ == '__main__':
    # Build the solver from the module-level configuration.
    ga = GA(
        DNA_size=DNA_SIZE,
        DNA_bound=ASCII_BOUND,
        cross_rate=CROSS_RATE,
        mutation_rate=MUTATION_RATE,
        pop_size=POP_SIZE,
    )

    for generation in range(N_GENERATIONS):
        # Pick the individual with the highest match count this generation.
        scores = ga.get_fitness()
        champion = ga.pop[scores.argmax()]
        # e.g. best_DNA-- [103  53  39  94 122 114  73  32  85 116  61]
        print("best_DNA--", champion)

        # Turn the numeric DNA back into an ASCII string.
        decoded = ga.translateDNA(champion)
        # e.g. best_phrase-- g5'^zrI Ut=
        print("best_phrase--", decoded)
        print('Gen', generation, ': ', decoded)

        # Keep evolving until the best individual spells the target exactly.
        if decoded != TARGET_PHRASE:
            ga.evolve()
            continue
        break
结果展示-控制台输出:
Gen 0 : I64&ghA i0f
Gen 1 : I64&ghA i0f
Gen 2 : YoZ$gS[ U%r
Gen 3 : Yo4RgQ[ i%r
Gen 4 : YsJ ghh iUf
Gen 5 : IouughE itc
Gen 6 : IouughE itc
Gen 7 : IouughE itc
Gen 8 : IouughE itc
Gen 9 : souugSE itc
Gen 10 : Iou&geD ihc
Gen 11 : Iou~ghE its
Gen 12 : You&geD itc
Gen 13 : You&geD itc
Gen 14 : You&geD itc
Gen 15 : Youmge[ itG
Gen 16 : YouvgeA i(!
Gen 17 : You&geD itc
Gen 18 : zou ghG it!
Gen 19 : You geG i0!
Gen 20 : You geG i0!
Gen 21 : You ge6 it&
Gen 22 : You ge6 it&
Gen 23 : You gez it!
Gen 24 : You gez it!
Gen 25 : You gez it!
Gen 26 : You gez it!
Gen 27 : You gek it@
Gen 28 : You geG it!
Gen 29 : You geG it!
Gen 30 : You geG it!
Gen 31 : You geG it!
Gen 32 : You geG it!
Gen 33 : You ged it!
Gen 34 : You geA it!
Gen 35 : You geA it!
Gen 36 : You geA it!
Gen 37 : You ged it!
Gen 38 : You geA it!
Gen 39 : You ged it!
Gen 40 : You ged it!
Gen 41 : You get it!
3、测试代码
1、np.fromstring测试
# Convert the phrase to its ASCII codes. np.fromstring on text is deprecated
# (and rejected by recent NumPy releases); encode to bytes and use
# np.frombuffer instead. The .copy() keeps the result writable.
TARGET_ASCII = np.frombuffer(TARGET_PHRASE.encode('ascii'), dtype=np.uint8).copy()
# [ 89 111 117  32 103 101 116  32 105 116  33]
print(TARGET_ASCII)
2、axis=1测试
print("---------sum(axis=1)测试-----------")
# Stand-in for the random population: two identical rows, each holding the
# ASCII codes of the target phrase.
pop = np.array([[89, 111, 117, 32, 103, 101, 116, 32, 105, 116, 33]] * 2)
# (2, 11)
print(pop.shape)
# Compare element-wise with the target, then count the matches in each row.
match_count = np.sum(pop == TARGET_ASCII, axis=1)
# [11 11]
print(match_count)