import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional: the GA core itself only needs NumPy
    plt = None

N_CITIES = 20        # number of cities (genes per individual)
CROSS_RATE = 0.1     # probability that an individual undergoes crossover
MUTATE_RATE = 0.02   # per-gene probability of a swap mutation
POP_SIZE = 500       # individuals per generation
N_GENERATOR = 500    # number of generations to run


class GA(object):
    """Permutation-encoded genetic algorithm for ordering cities.

    Every row of ``self.pop`` is one individual: a permutation of the city
    indices ``0 .. DNA_size - 1`` giving the order in which they are visited.
    """

    def __init__(self, DNA_size, cross_rate, mutation_rate, pop_size):
        self.DNA_size = DNA_size
        self.cross_rate = cross_rate
        self.mutate_rate = mutation_rate
        self.pop_size = pop_size
        # Start from pop_size independent random visiting orders.
        self.pop = np.vstack([np.random.permutation(DNA_size) for _ in range(pop_size)])

    def translateDNA(self, DNA, city_position):
        """Map visiting orders to the x / y coordinates of the visited cities.

        Args:
            DNA: integer array of city indices; each row (or a single 1-D
                route) lists cities in visiting order.
            city_position: array of shape (n_cities, 2) holding (x, y) pairs.

        Returns:
            (line_x, line_y): float64 arrays with the same shape as ``DNA``.
        """
        coords = np.asarray(city_position)[np.asarray(DNA)]  # shape DNA.shape + (2,)
        line_x = coords[..., 0].astype(np.float64)
        line_y = coords[..., 1].astype(np.float64)
        return line_x, line_y

    def get_fitness(self, lx, ly):
        """Score routes: the shorter the route, the (exponentially) fitter.

        The route length is the sum of straight-line segments between
        consecutive cities along the last axis; the leg back to the first
        city is NOT included (open path).

        Returns:
            ``exp(2 * DNA_size / route_length)``, one value per route.
        """
        lx = np.asarray(lx, dtype=np.float64)
        ly = np.asarray(ly, dtype=np.float64)
        segments = np.hypot(np.diff(lx, axis=-1), np.diff(ly, axis=-1))
        total_distance = segments.sum(axis=-1)
        return np.exp(self.DNA_size * 2 / total_distance)

    def select(self, fitnesss):
        """Fitness-proportional (roulette-wheel) selection with replacement.

        ``fitnesss`` (name kept for backward compatibility) must be
        non-negative with a positive sum, as ``np.random.choice`` requires.
        Returns a new array of ``pop_size`` rows drawn from ``self.pop``.
        """
        prob = fitnesss / fitnesss.sum()
        idx = np.random.choice(np.arange(self.pop_size), size=self.pop_size, replace=True, p=prob)
        return self.pop[idx]

    def cross_over(self, parent, pop):
        """Order-preserving crossover, applied with probability ``cross_rate``.

        A random subset of ``parent``'s genes is kept and placed first (in
        their original relative order); the missing cities are appended in
        the order they appear in a random partner row of ``pop``, so the
        result is always a valid permutation. ``parent`` is modified in
        place and returned.
        """
        if np.random.rand() < self.cross_rate:
            partner = pop[np.random.randint(0, len(pop))]
            cross_points = np.random.randint(0, 2, size=self.DNA_size).astype(bool)
            keep_city = parent[~cross_points]  # boolean indexing returns a copy
            swap_city = partner[np.isin(partner, keep_city, invert=True)]
            parent[:] = np.concatenate((keep_city, swap_city))
        return parent

    def mutate(self, child):
        """Swap mutation: each position swaps with a random position with
        probability ``mutate_rate``. ``child`` is modified in place and returned."""
        for point in range(self.DNA_size):
            if np.random.rand() < self.mutate_rate:
                swap_point = np.random.randint(0, self.DNA_size)
                child[point], child[swap_point] = child[swap_point], child[point]
        return child

    def evolve(self, fitness):
        """Advance one generation: select, then cross over and mutate each row."""
        pop = self.select(fitness)
        pop_copy = pop.copy()  # crossover partners come from the unmodified selection
        for parent in pop:     # each row is a view, so the edits land in ``pop``
            child = self.cross_over(parent, pop_copy)
            child = self.mutate(child)
            parent[:] = child
        self.pop = pop


class env(object):
    """Random cities in [0, 1) x [0, 1), plus a helper to draw one route."""

    def __init__(self, N_CITIES):
        # Attribute name spelling kept as-is for compatibility with callers.
        self.city_postion = np.random.rand(N_CITIES, 2)

    def plotting(self, lx, ly, total_d):
        """Redraw the cities and one route; ``total_d`` is shown in the label."""
        if plt is None:
            raise ImportError("matplotlib is required for env.plotting()")
        plt.cla()
        plt.scatter(self.city_postion[:, 0], self.city_postion[:, 1], s=200, c='k')
        plt.plot(lx.T, ly.T, 'r-')
        plt.text(-0.05, -0.05, "Total fitness=%.2f" % total_d, fontdict={'size': 20, 'color': 'red'})
        plt.xlim((-0.1, 1.1))
        plt.ylim((-0.1, 1.1))
        plt.pause(0.01)


if __name__ == '__main__':
    if plt is None:
        raise SystemExit("matplotlib is required to run this demo")
    ga = GA(DNA_size=N_CITIES, cross_rate=CROSS_RATE, mutation_rate=MUTATE_RATE, pop_size=POP_SIZE)
    envs = env(N_CITIES=N_CITIES)
    plt.ion()  # interactive mode so each generation redraws without blocking
    for generation in range(N_GENERATOR):
        lx, ly = ga.translateDNA(ga.pop, envs.city_postion)
        fitness = ga.get_fitness(lx, ly)
        ga.evolve(fitness)
        best_index = np.argmax(fitness)
        print('Gen:', generation, '| best fit: %.2f' % fitness[best_index])
        envs.plotting(lx[best_index], ly[best_index], fitness[best_index])
    plt.ioff()
    plt.show()
遗传算法
最新推荐文章于 2022-08-21 01:07:29 发布