# Use a genetic algorithm to find the point X at which a function attains its maximum.
"""
Visualize Genetic Algorithm to find a maximum point in a function.
"""
import numpy as np
import matplotlib. pyplot as plt
# Hyper-parameters of the genetic algorithm.
# Number of bits in each individual's binary DNA (second axis of the population array).
DNA_SIZE = 10
# Number of individuals in the population (first axis of the population array).
POP_SIZE = 100
# Probability that an individual is crossed with a randomly chosen mate.
CROSS_RATE = 0.8
# Per-gene probability of flipping a bit during mutation.
MUTATION_RATE = 0.003
# Number of iterations of the evolution loop.
N_GENERATIONS = 200
# Range of x used for plotting F; its upper bound also scales the decoded DNA values.
X_BOUND = [ 0 , 5 ]
def F(x):
    """Objective function whose maximum the genetic algorithm searches for.

    Computes ``sin(10 * x) * x + cos(2 * x) * x`` element-wise.

    Args:
        x: A scalar or array-like of sample points.

    Returns:
        The function value(s), with the same shape as ``x``.
    """
    # Coerce first: with a plain Python list, ``10 * x`` would repeat the
    # list ten times instead of scaling its elements.
    x = np.asarray(x)
    return np.sin(10 * x) * x + np.cos(2 * x) * x
def get_fitness(pred):
    """Turn raw objective values into strictly positive fitness scores.

    The values are shifted so that the smallest one becomes ``1e-3``. This
    keeps every score above zero, which matters when the scores are later
    normalised into selection probabilities.

    Args:
        pred: Array-like of objective values, one per individual.

    Returns:
        A NumPy array of the same shape with a minimum of ``1e-3``.
    """
    pred = np.asarray(pred)
    return pred + 1e-3 - np.min(pred)
def select(pop, fitness):
    """Resample the population with probability proportional to fitness.

    Individuals are drawn with replacement, so fitter individuals may
    appear several times and weak ones may disappear.

    Args:
        pop: 2-D NumPy array with one individual per row.
        fitness: Non-negative scores, one per row of ``pop``, with a
            positive sum.

    Returns:
        A new array with the same shape as ``pop`` holding the sampled rows.
    """
    fitness = np.asarray(fitness, dtype=float)
    # Size everything from the population actually passed in rather than
    # from a module-level constant.
    n_individuals = len(pop)
    idx = np.random.choice(
        np.arange(n_individuals),
        size=n_individuals,
        replace=True,
        p=fitness / fitness.sum(),
    )
    return pop[idx]
def crossover(parent, pop, cross_rate=None):
    """Mix genes from a random mate into ``parent`` (uniform crossover).

    With probability ``cross_rate`` a mate is picked uniformly from ``pop``
    and a random subset of gene positions in ``parent`` is overwritten with
    the mate's genes. ``parent`` is modified in place.

    Args:
        parent: 1-D NumPy array holding one individual's genes.
        pop: 2-D NumPy array of candidate mates, one per row.
        cross_rate: Probability of performing the crossover. Defaults to
            the module-level ``CROSS_RATE``.

    Returns:
        ``parent`` itself, after the optional in-place crossover.
    """
    if cross_rate is None:
        cross_rate = CROSS_RATE
    if np.random.rand() < cross_rate:
        mate_idx = np.random.randint(0, len(pop), size=1)
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in 1.24.
        cross_points = np.random.randint(0, 2, size=parent.shape[0]).astype(bool)
        parent[cross_points] = pop[mate_idx, cross_points]
    return parent
def mutate(child, mutation_rate=None):
    """Randomly flip genes of ``child`` in place.

    Each gene is visited once and, with probability ``mutation_rate``,
    replaced by ``1`` if it was ``0`` and by ``0`` otherwise.

    Args:
        child: Mutable sequence (e.g. a 1-D NumPy array) of genes.
        mutation_rate: Per-gene flip probability. Defaults to the
            module-level ``MUTATION_RATE``.

    Returns:
        ``child`` itself, after the in-place mutation.
    """
    if mutation_rate is None:
        mutation_rate = MUTATION_RATE
    # Iterate over the actual length of the individual instead of relying
    # on a module-level DNA size.
    for point in range(len(child)):
        if np.random.rand() < mutation_rate:
            child[point] = 1 if child[point] == 0 else 0
    return child
def translateDNA(pop, x_bound=None):
    """Decode binary DNA into real values inside the search interval.

    Each individual is read as an unsigned binary number (most significant
    bit first), normalised to ``[0, 1]`` and mapped linearly onto
    ``[x_bound[0], x_bound[1]]``.

    Args:
        pop: NumPy array of 0/1 genes; either one individual (1-D) or a
            population with one individual per row (2-D).
        x_bound: ``(low, high)`` pair describing the target interval.
            Defaults to the module-level ``X_BOUND``.

    Returns:
        A float for a single individual, or a 1-D array with one decoded
        value per row of ``pop``.
    """
    if x_bound is None:
        x_bound = X_BOUND
    low, high = x_bound[0], x_bound[1]
    # Take the bit width from the data so any DNA length decodes correctly.
    n_bits = pop.shape[-1]
    weights = 2 ** np.arange(n_bits)[::-1]
    fraction = pop.dot(weights) / float(2 ** n_bits - 1)
    # Honour the lower bound as well; for an interval starting at 0 this
    # reduces to ``fraction * high``.
    return low + fraction * (high - low)
# Initial population: POP_SIZE random bit strings, each DNA_SIZE genes long.
pop = np.random.randint(2, size=(POP_SIZE, DNA_SIZE))

# Interactive mode lets the scatter plot be redrawn every generation.
plt.ion()
x = np.linspace(*X_BOUND, 200)
plt.plot(x, F(x))

# Handle of the previous generation's scatter plot; None until the first draw.
sca = None
for _ in range(N_GENERATIONS):
    pop_x = translateDNA(pop)
    F_values = F(pop_x)

    # Replace the previous generation's points with the current ones.
    if sca is not None:
        sca.remove()
    sca = plt.scatter(pop_x, F_values, s=200, lw=0, c='red', alpha=0.5)
    plt.pause(0.05)

    # Evolution step: evaluate, select, then recombine and mutate in place.
    fitness = get_fitness(F_values)
    print("Most fitted DNA: ", pop[np.argmax(fitness), :])
    pop = select(pop, fitness)
    pop_copy = pop.copy()  # mates are drawn from an unmodified snapshot
    for parent in pop:
        child = crossover(parent, pop_copy)
        child = mutate(child)
        parent[:] = child  # the parent is replaced by its child

# `fitness` was computed before the last selection/crossover/mutation, so its
# argmax no longer lines up with the rows of `pop`. Re-evaluate the final
# population before reporting the best x.
final_fitness = get_fitness(F(translateDNA(pop)))
print("the X: ", translateDNA(pop[np.argmax(final_fitness), :]))

plt.ioff()
plt.show()