这两天写了一个比较通用的遗传算法框架common-geneticalgorithm,之所以写这个是因为以前每次需要用到遗传算法的时候总是手写一遍,从开始写代码起到现在估计至少写了不下20次了,像matlab,python,js,go,java的版本都写过。
主要思路是将重复的代码部分,像选择算子,精英保留,种群初始化,遗传算法的参数等写在框架里,剩下的交叉、变异、个体初始化,适应度计算部分使用抽象类代替,需要用户自己实现。
下面使用这个框架计算一个函数x2+3*sin(x*y*z)+y2-z^2的最大值以及对应的x,y,z的值,这个函数的图像如下:
这个函数在我的两篇博客里都计算过(https://blog.csdn.net/just_do_it_123/article/details/50993439,https://blog.csdn.net/just_do_it_123/article/details/50835793)
分别用粒子群算法和遗传算法寻找到的最优解如下:
在这里我用我写的这个框架来计算一下,看看结果怎么样,下面是代码:
package examples;
import algorithms.CommonGA;
import algorithms.Individual;
import com.zhangm.easyutil.Tuple;
/**
* 计算x^2+3*sin(x*y*z)+y^2-z^2的最大值,并求处于该值时x, y, z的值
* @Author zhangming
* @Date 2020/6/30 16:31
*/
public class SimpleIndividual extends Individual<double[], Double> {
public SimpleIndividual() {
double[] data = new double[] {Math.random() * 20 - 10, Math.random() * 20 - 10, Math.random() * 20 - 10};
this.setData(data);
this.setIndex(0.);
}
private double[][] bounds = new double[][] {new double[]{-10, 10}, new double[]{-10, 10}, new double[]{-10, 10}};
@Override
public SimpleIndividual mutate() {
SimpleIndividual _this = this.clone();
int dataLen = this.getData().length;
// 随机挑选一个值发生变异
int index = (int) Math.floor(Math.random() * dataLen);
if (Math.random() < 0.5) {
// 向下偏移0.1
_this.getData()[index] -= (this.bounds[index][1] - this.bounds[index][0]) * 0.1;
if (_this.getData()[index] < this.bounds[index][0]) {
_this.getData()[index] = this.bounds[index][0];
}
} else {
// 向上偏移0.1
_this.getData()[index] += (this.bounds[index][1] - this.bounds[index][0]) * 0.1;
if (_this.getData()[index] > this.bounds[index][1]) {
_this.getData()[index] = this.bounds[index][1];
}
}
return _this;
}
@Override
public void calculateIndex() {
// 计算指标
double x = this.getData()[0];
double y = this.getData()[1];
double z = this.getData()[2];
this.setIndex(x * x + 3 * Math.sin(x * y * z) + y * y - z * z);
}
@Override
public SimpleIndividual clone() {
double[] data = new double[3];
System.arraycopy(this.getData(), 0, data, 0, 3);
double index = this.getIndex();
SimpleIndividual newIndividual = new SimpleIndividual();
newIndividual.setData(data);
newIndividual.setIndex(index);
return newIndividual;
}
public static void main(String[] args) {
CommonGA.CrossFunction<double[], Double> crossFunction = (Individual<double[], Double> individual1, Individual<double[], Double> individual2) -> {
int index = (int) Math.floor(Math.random() * 3);
Individual<double[], Double> newIndividual1 = individual1.clone();
Individual<double[], Double> newIndividual2 = individual2.clone();
newIndividual1.getData()[index] = individual2.getData()[index];
newIndividual2.getData()[index] = individual1.getData()[index];
return new Tuple<>(newIndividual1, newIndividual2);
};
CommonGA.InitFunction<double[], Double> initFunction = SimpleIndividual::new;
// 由于算法以小为优,此处取最大值
CommonGA.ComparatorFunction<double[], Double> comparatorFunction = () -> (individual1, individual2) ->
individual1.getIndex() < individual2.getIndex() ? 1 : individual1.getIndex().equals(individual2.getIndex()) ? 0 : -1;
CommonGA<double[], Double> ga = CommonGA.of(SimpleIndividual.class, crossFunction, comparatorFunction, initFunction)
.withGeneralSize(200).withPopSize(100);
ga.start();
Individual<double[], Double> bestIndividual = ga.getBestIndividual();
System.out.printf("max: %f, x: %f, y: %f, z: %f\n", bestIndividual.getIndex(),
bestIndividual.getData()[0], bestIndividual.getData()[1], bestIndividual.getData()[2]);
}
}
结果:
max: 202.686378, x: -10.000000, y: -10.000000, z: -0.302324
可以看到这个结果和之前求解的值很接近,使用这个框架只需要自己实现个体初始化,变异,适应度值计算,个体复制,交叉,适应度比较这几个部分的代码,剩下的框架就可以进行计算了,还是很方便的。项目地址:
https://github.com/Mng12345/common-geneticalgorithm
欢迎大家来star一下哈,这个框架还是非常值得使用和学习的,其中关于泛型部分的设计我前前后后改了好几遍,终于像个样子了。对了,项目里有一个叫easy-util的依赖是我自己的工具包,用来操作数据很方便,项目地址见
https://github.com/Mng12345/easy-util/tree/master
为了增强求解效率,使用parallelStream来计算交叉算子
private void cross() {
List<Tuple<Integer, Integer>> crossedIndexes = new ArrayList<>();
for (int i=0; i<this.popSize; i++) {
if (Math.random() < this.pCross) {
// 从种群中随机挑选出两个不重复的个体
int[] index = new int[2];
RangeUtil.randomNotRepeat(this.individuals, 2, index);
crossedIndexes.add(new Tuple<>(index[0], index[1]));
}
}
crossedIndexes.parallelStream().forEach(twoCrossIndex -> {
Tuple<Individual<T1, T2>, Individual<T1, T2>> crossedIndividuals = this.crossFunction.apply(
this.individuals.get(twoCrossIndex.getV1()), this.individuals.get(twoCrossIndex.getV2()));
this.individuals.set(twoCrossIndex.getV1(), crossedIndividuals.getV1());
this.individuals.set(twoCrossIndex.getV2(), crossedIndividuals.getV2());
});
Streams.forEachIndexedParallel(this.individuals.stream(),
(individual, index) -> this.individuals.get(index.intValue()).calculateIndex());
}
欢迎大家关注我的公众号