python:neat-python遗传拓扑神经网络初步使用

neat-python是对遗传算法的拓展,结合了神经网络和遗传算法的模型,本文旨在使用不在讲解。

本文利用neat计算x²来进行一个neat的初步使用,通过本文教程可以大致使用neat完成一些基础任务。

安装neat

pip install neat-python

一、配置config文件

neat利用config文件来进行参数配置,具体的配置文件模板可以在官方文档中找到:配置文件模板地址。

1.1 模板

[NEAT]
fitness_criterion     = max
fitness_threshold     = 3.9
pop_size              = 50
reset_on_extinction   = False

[DefaultGenome]
# node activation options
activation_default      = square
activation_mutate_rate  = 0.0
activation_options      = square

# node aggregation options
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

# node bias options
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 30.0
bias_min_value          = -30.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1

# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

# connection add/remove rates
conn_add_prob           = 0.5
conn_delete_prob        = 0.5

# connection enable options
enabled_default         = True
enabled_mutate_rate     = 0.01

feed_forward            = True
initial_connection      = full

# node add/remove rates
node_add_prob           = 0.2
node_delete_prob        = 0.2

# network parameters
num_hidden              = 0
num_inputs              = 1
num_outputs             = 1

# node response options
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 30.0
response_min_value      = -30.0
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

# connection weight options
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 30
weight_min_value        = -30
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2

文件取名为config-feedforward.txt。
本文将不阐述每一个参数的内容,仅介绍将会使用的一些参数。

1.2.1 选择激活函数

在本文中,我们需要计算x²,所以首先在脑海中需要明白我们的图形产出是一个二次方程图形:
二次方程
所以我们选择了:

activation_default      = square
activation_mutate_rate  = 0.0
activation_options      = square

其他的激活函数可以查看官方文档:激活函数列表。
在使用neat时,需要明白我们输出的结果是什么,比如本文,我们知道结果是平方,所以用square;如果需要输出是或否,那可以选择sigmoid。

1.2.2 输入输出

在本文中,我们只有一个x,得到的结果只有一个y,例如2输出4,3输出9,所以在这部分我们设置为:

num_hidden              = 0
num_inputs              = 1
num_outputs             = 1

我们没有隐藏层,如果需要隐藏层,可以设置num_hidden。

1.2.3 种群数量及适应值阈值

本文中,将每一个族群设置为50个,阈值设为3.9,表示每一次计算会有50个子代,当适应值达到3.9,将停止迭代:

fitness_threshold     = 3.9
pop_size              = 50

1.2.4 其他参数

我们需要设置的参数绝大部分都是上面的内容,其他参数特殊情况特殊配置,一般使用情况下可以直接使用默认参数值。

二、编写neat

local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, 'config-feedforward.txt') # 配置文件名称
config = neat.config.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        config_path
    )# 配置文件中每一组
p = neat.Population(config)
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)

winner = p.run(eval_genomes, 1000) # 定义一个eval_genomes方法进行迭代,总共迭代100次

2.1 eval_genomes函数

def eval_genomes(genomes, config):
    # x ** 2
    x_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 样本
    y_list = [1, 4, 9, 16, 25, 36, 49, 64, 81, 100] # 样本结果

    for genome_id, genome in genomes: # 循环50个子代
        genome.fitness = 4.0 # 初始化适应度
        net = neat.nn.FeedForwardNetwork.create(genome, config) # 创建前馈网络
        for i, x in enumerate(x_list):
            output = net.activate((x, )) # 计算
            genome.fitness -= (output[0] - y_list[i]) ** 2 # 使用误差来进行适应值的降低

当进行p.run函数时,将进行eval_genomes函数的调用,在eval_genomes函数中,适当地给予每一个子代(基因)适应度的调整,来选择合适的子代,当子代达到我们的适应值阈值时,将会停止迭代。

三、运行

Population's average fitness: -42724756.48750 stdev: 220655177.29976
Best fitness: 3.98619 - size: (1, 1) - species 1 - id 3838

Best individual in generation 87 meets fitness threshold - complexity: (1, 1)

本次运行结果到第87代已经产生最佳适应度为3.98612的子代,我们可以通过保存计算结果来进行使用。

四、保存计算结果

4.1 保存

pickle.dump(winner, open("best.pickle", "wb"))

4.2 调用

path = "xx/xx/xx/best.pickle"
    with open(path, "rb") as f:
        net = pickle.load(f)

local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, 'config-feedforward.txt')
config = neat.config.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        config_path
    )

net = neat.nn.FeedForwardNetwork.create(net, config)
output = net.activate((4,))
print(output[0])

输出结果为:

15.842392247827298

可以看到,结果已经是我们想要的了。

五、完整代码

# -*- coding: utf-8 -*-
import os
import neat
import pickle


def eval_genomes(genomes, config):
    # x ** 2
    x_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    y_list = [1, 4, 9, 16, 25, 36, 49, 64, 81, 100]

    for genome_id, genome in genomes:
        genome.fitness = 4.0
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        for i, x in enumerate(x_list):
            output = net.activate((x, ))
            genome.fitness -= (output[0] - y_list[i]) ** 2


if __name__ == '__main__':
    path = "xx/xx/xx/best.pickle"
    with open(path, "rb") as f:
        net = pickle.load(f)

    local_dir = os.path.dirname(__file__)
    config_path = os.path.join(local_dir, 'config-feedforward.txt')
    config = neat.config.Config(
        neat.DefaultGenome,
        neat.DefaultReproduction,
        neat.DefaultSpeciesSet,
        neat.DefaultStagnation,
        config_path
    )

    net = neat.nn.FeedForwardNetwork.create(net, config)
    output = net.activate((4,))
    print(output[0])

    # local_dir = os.path.dirname(__file__)
    # config_path = os.path.join(local_dir, 'config-feedforward.txt')
    #
    # config = neat.config.Config(
    #     neat.DefaultGenome,
    #     neat.DefaultReproduction,
    #     neat.DefaultSpeciesSet,
    #     neat.DefaultStagnation,
    #     config_path
    # )
    #
    # p = neat.Population(config)
    #
    # p.add_reporter(neat.StdOutReporter(True))
    # stats = neat.StatisticsReporter()
    # p.add_reporter(stats)
    #
    # winner = p.run(eval_genomes, 1000)
    # pickle.dump(winner, open("best.pickle", "wb"))


  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值