文章目录
项目介绍
本文将基于 Tensorflow2.0 使用遗传算法对网络结构寻优来对 cifar10 数据集进行分类。此项目共包含两个文件,一个是用于导入数据集、构建网络以及训练网络的文件 main.py,另一个是实现遗传算法的文件 GA.py。
在构建网络时,我们将采用模块堆叠的方式,仿照 MobileNet V2 中的模块结构;在训练网络时,我们将学习率和优化器类型视为可变参数与网络结构共同放入遗传算法进行寻优。
main.py
1、导入需要的库
import tensorflow as tf
import numpy as np
import os
2、导入数据集
我们将这一操作封装到 load() 函数中,方便 GA.py 文件的调用。
def load(