机器学习课程结束后，作为遗传算法的初学者，我参考汤姆·米切尔（Tom M. Mitchell）的《机器学习》，简单实现了用遗传算法解决 PlayTennis 问题。我认为这个问题最大的难点在于如何设计适应度函数。
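首先把天气状况编码为二进制位串：

| 属性 | 取值 | 编码 |
|---|---|---|
| Outlook | Sunny | 100 |
| | Overcast | 010 |
| | Rain | 001 |
| Humidity | High | 10 |
| | Normal | 01 |
| Wind | Strong | 10 |
| | Weak | 01 |
| PlayTennis | Yes | 1 |
| | No | 0 |

例如，样例"Sunny, High, Weak, No"编码为 `100 10 01 0`，即 `10010010`；在假设中，若某属性的编码位全为 1（如 Outlook 为 `111`），则表示该属性取任意值均可。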
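GA(Fitness, Fitness_threshold, p, r, m)
- Fitness：适应度评分函数，为给定假设赋予一个评估分数
- Fitness_threshold：指定终止判据的阈值
- p：群体中包含的假设数量
- r：每一步中通过交叉取代群体成员的比例
- m：变异率

- 初始化群体：$P \leftarrow$ 随机产生的 p 个假设
- 评估：对于 P 中的每个 h，计算 Fitness(h)
- 当 $\max_{h \in P} Fitness(h) < Fitness\_threshold$ 时，做：
  产生新一代 $P_s$：
  1. 选择：用概率方法选择 P 的 $(1-r)p$ 个成员加入 $P_s$，其中从 P 中选择假设 $h_i$ 的概率为 $\Pr(h_i) = \frac{Fitness(h_i)}{\sum_{j=1}^{p} Fitness(h_j)}$。
  2. 交叉：按上述概率 $\Pr(h_i)$ 从 P 中选择 $\frac{r \cdot p}{2}$ 对假设，对每对假设 $\langle h_1, h_2 \rangle$，应用交叉算子产生两个后代，把所有的后代加入 $P_s$。
  3. 变异：使用均匀的概率从 $P_s$ 中选择 m% 的成员，对于选出的每个成员，在它的表示中随机选择一个位取反。
  4. 更新：$P \leftarrow P_s$，并对 P 中的每个 h 重新计算 Fitness(h)。
- 从 P 中返回适应度最高的假设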
主要代码：

public class Main {
/** Current population of hypotheses (bit strings). */
static List<String> pool = new ArrayList<>();
/** Scratch list into which each evolution stage writes its offspring. */
static List<String> newPool = new ArrayList<>();
/** Number of generations to evolve. */
static int GEN = 5;
/** Crossover rate, consulted by cross() once per attribute segment. */
static float p_c = 0.5f;
/** Mutation rate: probability that a member is mutated in one generation. */
static float p_y = 0.5f;
/**
 * Entry point: builds a random initial population of 300 8-bit hypotheses, evolves it for
 * {@code GEN} generations and prints the fittest hypothesis together with its fitness.
 *
 * @param args unused
 */
public static void main(String[] args) {
    Random rand = new Random();
    // Initialization: 300 random 8-bit hypotheses.
    for (int i = 0; i < 300; i++) {
        StringBuilder code = new StringBuilder();
        for (int j = 0; j < 8; j++) {
            code.append(rand.nextInt(2));
        }
        pool.add(code.toString());
    }
    System.out.println("--------初始化--------");
    print(pool);
    // Evolution.
    for (int i = 0; i < GEN; i++) {
        System.out.println("--------演化" + i + "--------");
        evolution(pool);
    }
    // Best hypothesis: track the running maximum. Comparing only neighbouring elements
    // (pool[i] vs pool[i + 1]) can miss it, e.g. fitness values [5, 1, 2] would yield 2.
    // Each fitness is computed once.
    String bestChoice = pool.get(0);
    float bestFitness = calculateFitness(bestChoice);
    for (int i = 1; i < pool.size(); i++) {
        String candidate = pool.get(i);
        float fitness = calculateFitness(candidate);
        if (fitness > bestFitness) {
            bestChoice = candidate;
            bestFitness = fitness;
        }
    }
    System.out.println("--------最佳假设--------");
    System.out.println(bestChoice);
    System.out.println("--------适应度--------");
    System.out.println(bestFitness);
}
/**
 * Computes the fitness of a hypothesis as the F1 score of its classification of the
 * training set.
 *
 * <p>A hypothesis covers a training instance when every attribute bit set in the instance
 * (all bits except the last one) is also set in the hypothesis. A covered instance is
 * correct when its label (last bit) equals the hypothesis' label (last bit), and wrong
 * otherwise.
 * <ul>
 *   <li>Precision = correct / covered.</li>
 *   <li>Recall = correct / number of training instances carrying the hypothesis' label.</li>
 * </ul>
 *
 * @param code hypothesis bit string; must be at least as long as the attribute part of
 *             every training string
 * @return the F1 score, or 0 when the hypothesis classifies no training instance correctly
 */
private static float calculateFitness(String code) {
    char label = code.charAt(code.length() - 1);
    int accCount = 0;   // covered instances whose label matches the hypothesis
    int wrongCount = 0; // covered instances whose label differs
    int labelCount = 0; // training instances carrying the hypothesis' label
    for (String data : trainDatas) {
        int attrLength = data.length() - 1; // the last bit is the Yes/No label
        char dataLabel = data.charAt(attrLength);
        if (dataLabel == label) {
            labelCount++;
        }
        // Coverage: every '1' of the instance must also be '1' in the hypothesis. This
        // generalizes a fixed "3 matching bits" test (one set bit per attribute).
        int required = 0;
        int matched = 0;
        for (int i = 0; i < attrLength; i++) {
            if (data.charAt(i) == '1') {
                required++;
                if (code.charAt(i) == '1') {
                    matched++;
                }
            }
        }
        if (matched != required) {
            continue; // hypothesis does not cover this instance
        }
        if (dataLabel == label) {
            accCount++;
        } else {
            wrongCount++;
        }
    }
    if (accCount == 0) {
        return 0;
    }
    float acc = (float) accCount / (accCount + wrongCount);
    // The recall denominator is derived from the training data instead of hard-coded
    // class sizes, so editing trainDatas cannot silently skew the fitness.
    float recall = (float) accCount / labelCount;
    return 2 * acc * recall / (acc + recall);
}
/**
 * Evolves the population by one generation: roulette-wheel selection, crossover, then point
 * mutation. After each stage the population is replaced by that stage's output (collected in
 * {@code newPool}) and printed.
 *
 * @param pool population to evolve; modified in place
 */
private static void evolution(List<String> pool) {
    Random rand = new Random();
    rouletteSelect(pool, 20, rand);
    System.out.println("--------轮盘赌--------");
    print(pool);
    crossoverStage(pool);
    System.out.println("--------交叉--------");
    print(pool);
    mutationStage(pool, rand);
    System.out.println("--------变异--------");
    print(pool);
}

/**
 * Roulette-wheel selection without replacement: draws up to {@code count} members, each with
 * probability proportional to its fitness among the members not yet drawn, and makes them the
 * new population.
 */
private static void rouletteSelect(List<String> pool, int count, Random rand) {
    List<Float> weights = new ArrayList<>();
    for (String hypothesis : pool) {
        weights.add(calculateFitness(hypothesis));
    }
    // Never draw more members than exist; an empty wheel cannot be spun.
    int draws = Math.min(count, pool.size());
    for (int n = 0; n < draws; n++) {
        float total = 0;
        for (float w : weights) {
            total += w;
        }
        float pick = rand.nextFloat() * total;
        // Fall back to the last slot in case float rounding leaves pick above every prefix sum.
        int chosen = weights.size() - 1;
        float cumulative = 0;
        for (int j = 0; j < weights.size(); j++) {
            cumulative += weights.get(j);
            if (pick <= cumulative) {
                chosen = j;
                break;
            }
        }
        newPool.add(pool.remove(chosen));
        weights.remove(chosen); // index-based removal keeps weights aligned with pool
    }
    commitNewPool(pool);
}

/**
 * Crossover over every pair of members.
 *
 * <p>Pairs whose last (label) bit agrees are recombined by {@code cross}, which appends two
 * children to {@code newPool}. Other pairs are copied unchanged. With n members this yields
 * n * (n - 1) individuals.
 */
private static void crossoverStage(List<String> pool) {
    for (int i = 0; i < pool.size() - 1; i++) {
        String first = pool.get(i);
        for (int j = i + 1; j < pool.size(); j++) {
            String second = pool.get(j);
            if (first.charAt(first.length() - 1) == second.charAt(second.length() - 1)) {
                cross(first, second);
            } else {
                newPool.add(first);
                newPool.add(second);
            }
        }
    }
    commitNewPool(pool);
}

/**
 * Mutation stage.
 *
 * <p>With probability {@code p_y} a member is passed to {@code variation}, which appends its
 * mutated copy to {@code newPool}. Otherwise the member is kept unchanged.
 */
private static void mutationStage(List<String> pool, Random rand) {
    for (String hypothesis : pool) {
        if (rand.nextFloat() < p_y) {
            variation(hypothesis);
        } else {
            newPool.add(hypothesis);
        }
    }
    commitNewPool(pool);
}

/** Replaces the contents of {@code pool} with {@code newPool} and empties {@code newPool}. */
private static void commitNewPool(List<String> pool) {
    pool.clear();
    pool.addAll(newPool);
    newPool.clear();
}
/**
 * Crossover of two parents at attribute granularity.
 *
 * <p>The hypothesis is split into the segments Outlook [0,3), Humidity [3,5), Wind [5,7)
 * and the label [7,8). For each segment, the children exchange the parents' values with
 * probability {@code p_c}. Both children are appended to {@code newPool}.
 *
 * @param parent1 first parent; must have at least 8 bits (only the first 8 are used)
 * @param parent2 second parent; must have at least 8 bits (only the first 8 are used)
 */
private static void cross(String parent1, String parent2) {
    // Boundaries of the attribute segments in the 8-bit encoding.
    final int[] bounds = {0, 3, 5, 7, 8};
    Random rand = new Random();
    StringBuilder child1 = new StringBuilder();
    StringBuilder child2 = new StringBuilder();
    for (int s = 0; s + 1 < bounds.length; s++) {
        String seg1 = parent1.substring(bounds[s], bounds[s + 1]);
        String seg2 = parent2.substring(bounds[s], bounds[s + 1]);
        // p_c is the crossover rate, so swap when the draw falls below it. Testing
        // "draw > p_c" would swap with probability 1 - p_c instead.
        if (rand.nextFloat() < p_c) {
            child1.append(seg2);
            child2.append(seg1);
        } else {
            child1.append(seg1);
            child2.append(seg2);
        }
    }
    newPool.add(child1.toString());
    newPool.add(child2.toString());
}
/**
 * Point mutation: flips one randomly chosen bit of {@code parent} and appends the result to
 * {@code newPool}. The last bit (the Yes/No label) is never mutated.
 *
 * @param parent hypothesis to mutate; must have at least 2 bits
 */
private static void variation(String parent) {
    Random rand = new Random();
    // nextInt(bound) is exclusive, so this yields an index in [0, length - 2]: any bit except
    // the label. A bound of length - 2 would also exclude the last attribute bit.
    int location = rand.nextInt(parent.length() - 1);
    // Compare against the character '0', not the integer 0 (NUL); the latter never matches,
    // so every mutation would only clear bits.
    char flipped = parent.charAt(location) == '0' ? '1' : '0';
    StringBuilder children = new StringBuilder(parent);
    children.setCharAt(location, flipped);
    newPool.add(children.toString());
}
/**
 * Prints every hypothesis of the given population, one per line.
 *
 * @param pool population to print
 */
private static void print(List<String> pool) {
    for (String hypothesis : pool) {
        System.out.println(hypothesis);
    }
}
/**
 * PlayTennis training examples, one 8-bit string each.
 *
 * <p>Fields, in order:
 * <ul>
 *   <li>Outlook: Sunny=100, Overcast=010, Rain=001</li>
 *   <li>Humidity: High=10, Normal=01</li>
 *   <li>Wind: Strong=10, Weak=01</li>
 *   <li>PlayTennis: Yes=1, No=0</li>
 * </ul>
 */
private static String[] trainDatas = new String[] {
    "10010010", // Sunny,    High,   Weak,   No
    "10010100", // Sunny,    High,   Strong, No
    "10001011", // Sunny,    Normal, Weak,   Yes
    "10001101", // Sunny,    Normal, Strong, Yes
    "01001011", // Overcast, Normal, Weak,   Yes
    "01010011", // Overcast, High,   Weak,   Yes
    "01001101", // Overcast, Normal, Strong, Yes
    "01010101", // Overcast, High,   Strong, Yes
    "00101100", // Rain,     Normal, Strong, No
    "00110100", // Rain,     High,   Strong, No
    "00101011", // Rain,     Normal, Weak,   Yes
    "00110011"  // Rain,     High,   Weak,   Yes
};
}