A Simple Genetic Algorithm: Pac-Man

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u013564276/article/details/53470049

Introduction to genetic algorithms:

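I have been listening to the program 卓老板聊科技 for a long time. A recent episode was about artificial intelligence, focusing on the genetic algorithm proposed by Holland, and in the episode he described an interesting small experiment in detail: Pac-Man.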

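First, a brief introduction to genetic algorithms:
1: To solve a particular problem, first randomly generate a number of entities that attempt the problem. The way each entity solves the problem is represented by its "genes": different entities have different genes, and therefore correspond to different solutions.
2: With these entities in hand, let each of them carry out the task, and score every entity by the same criterion according to how well it did.
3: Next comes the evolution phase. Each individual is given a probability of being selected according to its score, so the higher the score, the more likely it is to be chosen. Two individuals are selected, their genes are crossed over, and the result is mutated with a preset probability to produce a new individual. This is repeated until enough new individuals have been produced, which completes one round of evolution (one generation). Evolving repeatedly in this way yields satisfactory individuals after a number of generations.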
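Step 3 describes fitness-proportionate ("roulette-wheel") selection, whereas Algorithm.java below uses tournament selection. The following is a minimal sketch of steps 1 to 3 with roulette-wheel selection, assuming a toy problem (maximizing the number of 1 bits in a bit string, often called OneMax) rather than Pac-Man; the class name RouletteGaSketch and all parameter values are illustrative. Because scores can be negative (Pac-Man loses points for hitting walls), the sketch shifts scores so the lowest becomes 0 before spinning the wheel; one side effect is that the worst individual of a generation is never picked unless all scores are equal.

```java
import java.util.Random;

// Toy GA following steps 1-3: individuals are bit strings and the score is the
// number of 1 bits ("OneMax"). Parents are chosen by roulette-wheel selection.
class RouletteGaSketch {

    // Step 2: score one individual with a fixed rule (here: count the 1 bits).
    static int score(boolean[] genes) {
        int s = 0;
        for (boolean g : genes) if (g) s++;
        return s;
    }

    // Step 3a: return an index with probability proportional to its score.
    // Scores are shifted so the minimum is 0, because the wheel needs non-negative weights.
    static int rouletteSelect(int[] scores, Random rng) {
        int min = Integer.MAX_VALUE;
        for (int s : scores) min = Math.min(min, s);
        long total = 0;
        for (int s : scores) total += (long) s - min;
        if (total == 0) return rng.nextInt(scores.length); // all scores equal: pick uniformly
        long r = (long) (rng.nextDouble() * total);         // 0 <= r < total
        for (int i = 0; i < scores.length; i++) {
            r -= (long) scores[i] - min;
            if (r < 0) return i;
        }
        return scores.length - 1; // safety net against floating-point rounding
    }

    // Step 3b: build the next generation by selection, uniform crossover and mutation.
    static boolean[][] evolve(boolean[][] pop, double mutationRate, Random rng) {
        int[] scores = new int[pop.length];
        for (int i = 0; i < pop.length; i++) scores[i] = score(pop[i]);
        boolean[][] next = new boolean[pop.length][];
        for (int i = 0; i < pop.length; i++) {
            boolean[] a = pop[rouletteSelect(scores, rng)];
            boolean[] b = pop[rouletteSelect(scores, rng)];
            boolean[] child = new boolean[a.length];
            for (int g = 0; g < a.length; g++) {
                child[g] = rng.nextBoolean() ? a[g] : b[g];                 // crossover
                if (rng.nextDouble() < mutationRate) child[g] = !child[g]; // mutation
            }
            next[i] = child;
        }
        return next;
    }

    static int best(boolean[][] pop) {
        int b = Integer.MIN_VALUE;
        for (boolean[] ind : pop) b = Math.max(b, score(ind));
        return b;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        // Step 1: a random initial population of 50 individuals with 40 genes each.
        boolean[][] pop = new boolean[50][40];
        for (boolean[] ind : pop)
            for (int g = 0; g < ind.length; g++) ind[g] = rng.nextBoolean();
        System.out.println("generation 0, best score: " + best(pop));
        for (int gen = 1; gen <= 60; gen++) pop = evolve(pop, 0.01, rng);
        System.out.println("generation 60, best score: " + best(pop));
    }
}
```

Both schemes satisfy "the higher the score, the more likely to be selected": roulette-wheel selection reacts to the size of score differences, while tournament selection depends only on the rank order of the individuals drawn into each tournament.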

Next, a brief introduction to the Pac-Man experiment:

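Pac-Man lives in a rectangular space made up of 10×10 cells, and 50 beans are placed at random among these 100 cells, so each cell is either empty or holds one bean. Pac-Man starts in a random cell and must then use its own strategy to eat beans. It has only 200 steps in total: eating a bean earns +10 points, hitting a wall costs 5 points, and performing the eat action on a cell with no bean costs 1 point. In addition, Pac-Man can see only 5 cells: the one it is in and the cells above, below, left, and right of it.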
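The world described above fits in a few lines. This is a hypothetical sketch (PacWorldSketch and its constants are illustrative names, not part of the original code): it places exactly 50 beans by shuffling the 100 cell indices and taking the first 50, which is an alternative to the retry loop used in Map.setBeans, and it records the reward for each action outcome.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Sketch of the experiment's world: a 10x10 grid with 50 beans and the reward table.
class PacWorldSketch {
    static final int WIDTH = 10, HEIGHT = 10, BEANS = 50, STEPS = 200;

    // Rewards from the rules: eat a bean +10, hit a wall -5, eat on an empty cell -1.
    static final int EAT_BEAN = 10, HIT_WALL = -5, EAT_NOTHING = -1, OTHER = 0;

    // Places min(beans, width*height) beans on distinct cells by shuffling all cell indices.
    static boolean[][] randomGrid(int width, int height, int beans, Random rng) {
        List<Integer> cells = new ArrayList<>();
        for (int i = 0; i < width * height; i++) cells.add(i);
        Collections.shuffle(cells, rng);
        boolean[][] grid = new boolean[width][height];
        for (int k = 0; k < Math.min(beans, cells.size()); k++) {
            int idx = cells.get(k);
            grid[idx / height][idx % height] = true; // same index split as Map.setBeans
        }
        return grid;
    }

    static int countBeans(boolean[][] grid) {
        int n = 0;
        for (boolean[] column : grid)
            for (boolean cell : column)
                if (cell) n++;
        return n;
    }

    public static void main(String[] args) {
        boolean[][] grid = randomGrid(WIDTH, HEIGHT, BEANS, new Random(7));
        System.out.println("beans placed: " + countBeans(grid)); // prints "beans placed: 50"
        // Upper bound on one game's score: all 50 beans eaten.
        System.out.println("max possible score: " + BEANS * EAT_BEAN); // prints "max possible score: 500"
    }
}
```

Shuffling guarantees distinct cells in a single pass, whereas the retry loop in Map.setBeans slows down as the grid fills up; at 50 beans in 100 cells either approach is cheap.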

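To summarize:
Pac-Man's possible actions: move up, move down, move left, move right, eat, stay put, and move randomly, 7 in total.
States Pac-Man can observe: each cell is in one of 3 states (bean, no bean, or wall), and there are 5 cells, giving 3^5 = 243 states.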
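The 243 figure comes from treating the five observed cells as the digits of a 5-digit base-3 number. A minimal sketch, assuming the same digit order (middle, up, right, down, left) and cell codes (0 = wall, 1 = bean, 2 = no bean) as the State and Individual classes below; StateCodeSketch is an illustrative name:

```java
// Encodes one observation as a base-3 number, using the digit order of
// Individual.getActionCode (middle, up, right, down, left) and the cell codes
// of State (0 = wall, 1 = bean, 2 = no bean).
class StateCodeSketch {
    static int encode(int middle, int up, int right, int down, int left) {
        int code = 0;
        for (int d : new int[] {middle, up, right, down, left}) {
            if (d < 0 || d > 2) throw new IllegalArgumentException("cell state must be 0, 1 or 2");
            code = code * 3 + d; // shift one base-3 place, then add the next digit
        }
        return code;
    }

    public static void main(String[] args) {
        int states = 1;
        for (int cell = 0; cell < 5; cell++) states *= 3; // 3 states per cell, 5 cells
        System.out.println(states);                       // 243
        System.out.println(encode(2, 2, 2, 2, 2));        // 242, the largest code
        // Bean underfoot, wall above, empty cells right, down and left:
        // 1*81 + 0*27 + 2*9 + 2*3 + 2 = 107
        System.out.println(encode(1, 0, 2, 2, 2));        // 107
    }
}
```

Codes run from 0 to 242, so each one can index directly into a 243-entry gene array. Some codes never occur in practice: any code whose middle digit is 0 (wall), i.e. codes 0 to 80, since Pac-Man never stands on a wall.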

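Therefore, each Pac-Man's genome can be represented as a sequence of 243 genes, one for each of the 243 states; each gene takes one of 7 values, indicating the action Pac-Man takes in that state.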
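A genome is therefore just a 243-entry lookup table with 7 choices per entry, which means there are 7^243 possible genomes, far too many to try one by one; this is why the strategy is evolved rather than enumerated. A small sketch (GenomeSketch is a hypothetical name) that builds a random genome and counts the digits of 7^243 with BigInteger:

```java
import java.math.BigInteger;
import java.util.Random;

// A genome is a lookup table: gene i holds the action (0..6) for state code i.
class GenomeSketch {
    static final int STATES = 243; // 3^5 observable situations
    static final int ACTIONS = 7;  // up, left, down, right, random move, eat, stay put

    static byte[] randomGenome(Random rng) {
        byte[] genes = new byte[STATES];
        for (int i = 0; i < STATES; i++) genes[i] = (byte) rng.nextInt(ACTIONS);
        return genes;
    }

    // Number of distinct genomes: 7 choices for each of 243 genes.
    static BigInteger searchSpace() {
        return BigInteger.valueOf(ACTIONS).pow(STATES);
    }

    public static void main(String[] args) {
        byte[] genes = randomGenome(new Random(0));
        System.out.println("action chosen in state 107: " + genes[107]);
        System.out.println("7^243 has " + searchSpace().toString().length() + " digits"); // 206 digits
    }
}
```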

Code

Main.java

public class Main {
    public static void main(String[] args) {
        Population population = new Population(1000, false);
        System.out.println(population);
        long count = 1;
        while (true){                           
            Population newPopulation = Algorithm.evolve(population);
            if (count % 5 == 0) {
                System.out.println("Generation " + count + ":");
                System.out.println(newPopulation);  
            }
            population = newPopulation;
            count++;            
        }
    }       
}

Individual.java

public class Individual {

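    // Pac-Man has 3^5 possible states: it observes 5 cells (up, down, left, right, and its own cell), and each cell is a wall, has a bean, or has no bean.
    private static int length = 243;
    /* Pac-Man has 7 actions:
     * 0 : up       4 : random move
     * 1 : left     5 : eat
     * 2 : down     6 : stay put
     * 3 : right
     */
    private static byte actionNum = 7;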

    private byte genes[] = null;
    private int fitness = Integer.MIN_VALUE;

    public Individual() {
        genes = new byte[length];       
    }

    public void generateGenes(){        
        for (int i = 0; i < length; i++) {
            byte gene = (byte) Math.floor(Math.random() * actionNum);
            genes[i] = gene;
        }
    }

    public int getFitness() {
        if (fitness == Integer.MIN_VALUE) {
            fitness = FitnessCalc.getFitnessPall(this);
        }
        return fitness;
    }


    public int getLength() {
        return length;
    }

    public byte getGene(int index) {
        return genes[index];
    }

    public void setGene(int index, byte gene) {
        this.genes[index] = gene;
        fitness = Integer.MIN_VALUE;
    }

    // State code: 5 base-3 digits; the first is the middle cell, then up, right, down, left
    public byte getActionCode(State state) {
        // 81 = 3^4, 27 = 3^3, 9 = 3^2; plain integer arithmetic instead of Math.pow
        int stateCode = state.getMiddle() * 81 + state.getUp() * 27 + state.getRight() * 9 + state.getDown() * 3 + state.getLeft();
        return genes[stateCode];
    }

    @Override
    public String toString() {  
        StringBuffer bf = new StringBuffer();
        for (int i = 0; i < length; i++) {
            bf.append(genes[i]);
        }
        return bf.toString();
    }

    public static void main(String[] args) {
        Individual ind = new Individual();
        ind.generateGenes();
        System.out.println(ind);
        System.out.println(ind.getFitness());
        System.out.println(FitnessCalc.getFitnessPall(ind));
    }
}

State.java

public class State {
    // 0 = wall, 1 = bean, 2 = no bean
    private byte middle;
    private byte up;
    private byte right;
    private byte down;
    private byte left;

    public State(byte middle, byte up, byte right, byte down, byte left) {
        this.middle = middle;
        this.up = up;
        this.right = right;
        this.down = down;
        this.left = left;
    }

    public byte getMiddle() {
        return middle;
    }

    public void setMiddle(byte middle) {
        this.middle = middle;
    }

    public byte getUp() {
        return up;
    }

    public void setUp(byte up) {
        this.up = up;
    }

    public byte getRight() {
        return right;
    }

    public void setRight(byte right) {
        this.right = right;
    }

    public byte getDown() {
        return down;
    }

    public void setDown(byte down) {
        this.down = down;
    }

    public byte getLeft() {
        return left;
    }

    public void setLeft(byte left) {
        this.left = left;
    }


}

Algorithm.java

public class Algorithm {
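    /* GA parameters */
    private static final double uniformRate = 0.5; // uniform crossover: probability that each gene is copied from the first parent
    private static final double mutationRate = 0.0001; // per-gene mutation probability
    private static final int tournamentSize = 3; // number of individuals drawn into each tournament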

    public static Population evolve(Population pop) {
        Population newPopulation = new Population(pop.size(), true);

        for (int i = 0; i < pop.size(); i++) {
            // Pick two fit parents by tournament selection
            Individual indiv1 = tournamentSelection(pop);
            Individual indiv2 = tournamentSelection(pop);           
            // Uniform crossover
            Individual newIndiv = crossover(indiv1, indiv2);
            newPopulation.saveIndividual(i, newIndiv);  
        }

        // Mutate the new population
        for (int i = 0; i < newPopulation.size(); i++) {
            mutate(newPopulation.getIndividual(i));
        }   
        return newPopulation;       
    }       

    // Tournament selection: draw tournamentSize individuals at random and return the fittest one as a crossover parent
    private static Individual tournamentSelection(Population pop) {
        // Create a tournament population
        Population tournamentPop = new Population(tournamentSize, true);
        // Randomly draw tournamentSize individuals into tournamentPop
        for (int i = 0; i < tournamentSize; i++) {
            int randomId = (int) (Math.random() * pop.size());
            tournamentPop.saveIndividual(i, pop.getIndividual(randomId));
        }
        // Return the fittest individual in the tournament
        Individual fittest = tournamentPop.getFittest();
        return fittest;
    }

    // Uniform crossover of two individuals: each gene is copied from indiv1 with probability uniformRate, otherwise from indiv2
    private static Individual crossover(Individual indiv1, Individual indiv2) {
        Individual newSol = new Individual();
        // Take each gene at random from one of the two parents
        for (int i = 0; i < indiv1.getLength(); i++) {
            if (Math.random() <= uniformRate) {
                newSol.setGene(i, indiv1.getGene(i));
            } else {
                newSol.setGene(i, indiv2.getGene(i));
            }
        }
        return newSol;
    }

    // Mutate an individual: each gene is replaced by a random action with probability mutationRate
    private static void mutate(Individual indiv) {
        for (int i = 0; i < indiv.getLength(); i++) {
            if (Math.random() <= mutationRate) {
                // Random action code 0-6
                byte gene = (byte) Math.floor(Math.random() * 7);
                indiv.setGene(i, gene);
            }
        }
    }
}
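Two details of these parameters are easy to misread. uniformRate = 0.5 is not a probability that crossover happens at all: crossover always happens, and uniformRate is the per-gene probability of copying from the first parent (uniform crossover). With mutationRate = 0.0001 and 243 genes, a child receives on average 243 × 0.0001 ≈ 0.024 mutation events, or roughly 24 across a generation of 1000. The sketch below (VariationSketch is an illustrative name) reimplements both operators on plain byte arrays so those numbers can be checked; since a mutated gene is redrawn from all 7 actions, about 1 in 7 mutation events leaves the gene unchanged.

```java
import java.util.Random;

// Sketch of the two variation operators used in Algorithm.java, on plain byte arrays.
class VariationSketch {
    // Uniform crossover: each gene comes from parent a with probability uniformRate, else from b.
    static byte[] uniformCrossover(byte[] a, byte[] b, double uniformRate, Random rng) {
        byte[] child = new byte[a.length];
        for (int i = 0; i < a.length; i++) {
            child[i] = rng.nextDouble() < uniformRate ? a[i] : b[i];
        }
        return child;
    }

    // Point mutation: each gene is independently redrawn from 0..actions-1.
    // Returns how many genes were redrawn (the new value may equal the old one).
    static int mutate(byte[] genes, double mutationRate, int actions, Random rng) {
        int redrawn = 0;
        for (int i = 0; i < genes.length; i++) {
            if (rng.nextDouble() < mutationRate) {
                genes[i] = (byte) rng.nextInt(actions);
                redrawn++;
            }
        }
        return redrawn;
    }

    public static void main(String[] args) {
        Random rng = new Random(3);
        int genomeLength = 243, populationSize = 1000;
        double mutationRate = 0.0001;
        // Expected mutation events per generation: 243 * 0.0001 * 1000, roughly 24.3
        System.out.println("expected per generation: " + genomeLength * mutationRate * populationSize);
        int total = 0;
        for (int n = 0; n < populationSize; n++) {
            total += mutate(new byte[genomeLength], mutationRate, 7, rng);
        }
        System.out.println("observed in this run: " + total);
    }
}
```

A rate this low means most children are pure recombinations of their parents, so new behaviour enters the population slowly; whether that is desirable is a tuning choice rather than something the algorithm fixes.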

Population.java

public class Population {

    private Individual[] individuals;

    public Population(int size, boolean lazy) {
        individuals = new Individual[size];
        if (!lazy) {
            for (int i = 0; i < individuals.length; i++) {
                Individual ind = new Individual();
                ind.generateGenes();
                individuals[i] = ind;
            }
        }
    }

    public void saveIndividual(int index, Individual ind) {
        individuals[index] = ind;
    }

    public Individual getIndividual(int index) {
        return individuals[index];
    }

    public Individual getFittest() {
        Individual fittest = individuals[0];
        // Loop through individuals to find fittest
        for (int i = 1; i < size(); i++) {
            if (fittest.getFitness() <= getIndividual(i).getFitness()) {
                fittest = getIndividual(i);
            }
        }
        return fittest;
    }

    public Individual getLeastFittest() {
        Individual ind = individuals[0];
        for (int i = 1; i < size(); i++) {
            if (ind.getFitness() > getIndividual(i).getFitness()) {
                ind = getIndividual(i);
            }
        }
        return ind;
    }

    public double getAverageFitness() {
        double sum = 0;
        for (int i = 0; i < size(); i++) {
            sum += individuals[i].getFitness();
        }
        return sum / size();
    }

    public int size() {
        return individuals.length;
    }

    @Override
    public String toString(){
        StringBuffer bf = new StringBuffer();
        bf.append("Population size: " + size() + "\n");
        bf.append("Max Fitness: " + getFittest().getFitness() + "\n");
        bf.append("Least Fitness: " + getLeastFittest().getFitness() + "\n");
        bf.append("Average Fitness: " + getAverageFitness() + "\n");        
        return bf.toString();
    }

    public static void main(String[] args) {
        Population population = new Population(8000, false);
        System.out.println(population);    
    }
}

MapMgr.java

public class MapMgr {

    private static int x = 10;
    private static int y = 10;
    private static int beanNum = 50;
    private static int mapNum = 100;

    private static MapMgr manager = null;       
    private Map[] maps = null;

    private MapMgr() {
        maps = new Map[mapNum];
        for (int i = 0; i < mapNum; i++) {
            Map map = new Map(x, y);
            map.setBeans(beanNum);
            maps[i] = map;
        }
    }

    synchronized public static MapMgr getInstance() {
        if (manager == null) manager = new MapMgr();
        return manager;
    }

    public Map getMap(int index) {
        Map map = null;
        index = index % mapNum;
        try {
            map = maps[index].clone();
        } catch (CloneNotSupportedException e) {
            e.printStackTrace();
        }
        return map;     
    }

    public static void main(String[] args) {
        MapMgr mgr = MapMgr.getInstance();
        mgr.getMap(1).print();
        System.out.println("--------------");
        mgr.getMap(2).print();
    }
}

Map.java

import java.awt.Point;

public class Map implements Cloneable{

    private int x = -1;
    private int y = -1;
    private int total = -1;
    private byte[][] mapGrid = null;

    public Map(int x, int y) {
        this.x = x;
        this.y = y;
        mapGrid = new byte[x][y];
        total = x * y;
    }

    public void setBeans(int num) {
        //check num 
        if (num > total) {
            num = total;
        }
        for (int i = 0; i < num; i++) {
            int address, xp, yp;
            do{
                address = (int) Math.floor((Math.random() * total)); // random number in 0 .. total-1
                xp = address / y;
                yp = address % y;   
                //System.out.println(xp+ ":" + yp + ":" + address + ":" + total);
            } while (mapGrid[xp][yp] != 0);
            mapGrid[xp][yp] = 1;            
        }

    }

    public boolean isInMap(int x, int y) {      
        if (x < 0 || x >= this.x) return false;
        if (y < 0 || y >= this.y) return false;     
        return true;
    }

    public boolean hasBean(int x, int y) {
        return mapGrid[x][y] != 0;
    }

    public boolean eatBean(int x, int y) {
        if(hasBean(x, y)) {
            mapGrid[x][y] = 0;
            return true;
        }
        return false;
    }

    public Point getStartPoint() {              
        int x = (int) Math.floor(Math.random() * this.x);
        int y = (int) Math.floor(Math.random() * this.y);       
        return new Point(x, y);
    }

    public State getState(Point p) {        
        byte middle = stateOfPoint(p);
        byte up = stateOfPoint(new Point(p.x, p.y - 1));
        byte right = stateOfPoint(new Point(p.x + 1, p.y));
        byte down = stateOfPoint(new Point(p.x, p.y + 1));
        byte left = stateOfPoint(new Point(p.x - 1, p.y));
        return new State(middle, up, right, down, left);
    }

    // 0 = wall, 1 = bean, 2 = no bean
    private byte stateOfPoint(Point p) {
        byte ret;

        if (!isInMap(p.x, p.y)) ret = 0;            
        else if (mapGrid[p.x][p.y] == 0) ret =  2;
        else ret = 1;

        return ret;
    }


    @Override
    public Map clone() throws CloneNotSupportedException {
        Map m = (Map) super.clone();
        byte[][] mapGrid = new byte[x][y];
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                mapGrid[i][j] = this.mapGrid[i][j];
            }
        }
        m.mapGrid = mapGrid;
        return m;       
    }

    public void print() {
        for (int i = 0; i < y; i++) {
            for (int j = 0; j < x; j++) {
                System.out.print(mapGrid[j][i]);
            }
            System.out.println();
        }
    }

    public static void main(String[] args) {
        Map m = new Map(10, 5);
        Map m1 = null;
        try {
            m1 = m.clone();
        } catch (CloneNotSupportedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        m.setBeans(40);
        m.print();
        m1.setBeans(15);
        m1.print();
    }

}

FitnessCalc.java

import java.awt.Point;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;

public class FitnessCalc {
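    /* Scoring of action outcomes:
     * hitting a wall: -5
     * eating a bean: +10
     * eating on an empty cell: -1
     * anything else: 0
     */
    // Number of games simulated per evaluation
    private static int DefaultSimTimes = 1000;
    // Number of steps per game
    private static int simSteps = 200;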
    private static int cores = 4;

    public static int getFitness(Individual ind) {
        return getFitness(ind, DefaultSimTimes);
    }

    public static int getFitness(Individual ind, int simTimes) {
        int fitness = 0;        
        MapMgr mgr = MapMgr.getInstance();  
        for (int i = 0; i < simTimes; i++) {
            Map map = mgr.getMap(i);
            Point point = map.getStartPoint();  
            for (int j = 0; j < simSteps; j++) {
                State state = map.getState(point);
                byte actionCode = ind.getActionCode(state);
                fitness += action(point, map, actionCode);
                //map.print();
                //System.out.println("---");
            }                               
        }       
        return fitness / simTimes;
    }

    public static int getFitnessPall(Individual ind) {
        int fitness = 0;        
        if (DefaultSimTimes < 100) {
            fitness = getFitness(ind);
        } else {                            
            FutureTask<Integer>[] tasks = new FutureTask[cores];            
            for (int i = 0; i < cores; i++) {
                FitnessPall pall = null;
                if (i == 0) {
                    pall = new FitnessPall(ind, (DefaultSimTimes / cores) + DefaultSimTimes % cores);
                } else {
                    pall = new FitnessPall(ind, DefaultSimTimes / cores);   
                }               
                tasks[i] = new FutureTask<Integer>(pall);
                Thread thread = new Thread(tasks[i]);
                thread.start();
            }       
            for (int i = 0; i < cores; i++) {
                try {
                    fitness += tasks[i].get();
                } catch (InterruptedException | ExecutionException e) {
                    e.printStackTrace();
                }
            }
            fitness = fitness / cores;
        }
        return fitness;
    }


    private static int action(Point point, Map map, int actionCode) {
        int sorce = 0;
        switch (actionCode) {
        case 0:
            if (map.isInMap(point.x, point.y - 1)) {
                sorce = 0;
                point.y = point.y - 1;
            } else {
                sorce = -5;
            }           
            break;
        case 1:
            if (map.isInMap(point.x - 1, point.y)) {
                sorce = 0;
                point.x = point.x - 1;
            } else {
                sorce = -5;
            }
            break;
        case 2:
            if (map.isInMap(point.x, point.y + 1)) {
                sorce = 0;
                point.y = point.y + 1;
            } else {
                sorce = -5;
            }
            break;
        case 3: 
            if (map.isInMap(point.x + 1, point.y)) {
                sorce = 0;
                point.x = point.x + 1;
            } else {
                sorce = -5;
            }
            break;
        case 4:
            int randomCode = (int) Math.floor(Math.random() * 4);
            sorce = action(point, map, randomCode);         
            break;
        case 5:
            if (map.eatBean(point.x, point.y)) {
                sorce = 10;             
            } else {
                sorce = -1;
            }
            break;
        case 6: 
            sorce = 0;
            break;
        }
        return sorce;
    }


}

class FitnessPall implements Callable<Integer> {
    private int simTimes;
    private Individual ind;
    public FitnessPall(Individual ind, int simTimes) {
        this.ind = ind;
        this.simTimes = simTimes;       
    }

    @Override
    public Integer call() throws Exception {
        return FitnessCalc.getFitness(ind, simTimes);       
    }   
}
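getFitnessPall starts four brand-new threads for every individual it scores, so a population of 1000 can create up to 4000 short-lived threads per generation; it also averages four per-chunk averages that have each already been truncated by integer division. A common alternative, sketched here under stated assumptions, is to reuse one fixed-size thread pool (java.util.concurrent.ExecutorService) and divide the summed score by the total number of games once. PooledFitnessSketch and the simulate function are hypothetical; simulate stands in for one 200-step game.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.IntToLongFunction;

// Scores `runs` simulated games on a shared thread pool and returns the exact mean.
class PooledFitnessSketch {
    // One pool for the whole program; daemon threads so the pool never keeps the JVM alive.
    private static final ExecutorService POOL = Executors.newFixedThreadPool(
            Runtime.getRuntime().availableProcessors(),
            r -> {
                Thread t = new Thread(r);
                t.setDaemon(true);
                return t;
            });

    // `simulate` is a stand-in for one game: run index -> score of that game.
    static double averageScore(IntToLongFunction simulate, int runs, int chunks) {
        List<Future<Long>> parts = new ArrayList<>();
        for (int c = 0; c < chunks; c++) {
            final int from = (int) ((long) runs * c / chunks);     // spreads any remainder evenly
            final int to = (int) ((long) runs * (c + 1) / chunks);
            Callable<Long> task = () -> {
                long sum = 0;
                for (int i = from; i < to; i++) sum += simulate.applyAsLong(i);
                return sum;
            };
            parts.add(POOL.submit(task));
        }
        long total = 0;
        try {
            for (Future<Long> f : parts) total += f.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while scoring", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a simulation failed", e.getCause());
        }
        return (double) total / runs; // sum first, divide once by the true game count
    }

    public static void main(String[] args) {
        // Toy game whose score is the run index mod 10: the exact mean over 1000 runs is 4.5.
        System.out.println(averageScore(i -> i % 10, 1000, 4)); // 4.5
    }
}
```

Because the pool outlives each call, thread start-up is paid once for the whole run instead of once per individual, and summing before dividing keeps the result exact even when the number of games is not a multiple of the number of chunks.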