优化算法学习笔记

本博客文章(学习笔记)导航 (点击这里访问)
在这里插入图片描述

一 遗传算法

1 遗传算法介绍

	遗传算法(Genetic Algorithm)遵循『适者生存』、『优胜劣汰』的原则,是一类借鉴生物界自然选择和自然遗传机制的随机化搜索算法。
	遗传算法模拟一个人工种群的进化过程,通过选择(Selection)、交叉(Crossover)以及变异(Mutation)等机制,在每次迭代中都保留一组候选个体,重复此过程,种群经过若干代进化后,理想情况下其适应度达到近似最优的状态。
	自从遗传算法被提出以来,其得到了广泛的应用,特别是在函数优化、生产调度、模式识别、神经网络、自适应控制等领域,遗传算法发挥了很大的作用,提高了一些问题求解的效率

2 遗传算法的步骤

image.png

3 遗传算法的Java例子

image.png

3.1 普通函数(基本遗传算法)

1 编码与解码
编码:初始化,将染色体的每一位都随机置为01
//初始化群体(编码)
public class ClsInit {
	//初始化一条染色体
	public String initSingle(int GENE){
		String res = "";
		for(int i = 0; i < GENE; i++){
			if(Math.random() < 0.5){
				res += 0;
			}else{
				res +=1;
			}
		}
		return res;
	}
 
	//初始化一组染色体
	public String[] initAll(int GENE,int groupsize){
		String[] iAll = new String[groupsize];
		for(int i = 0; i < groupsize; i++){
			iAll[i] = initSingle(GENE);
		}
		return iAll;
	}
}

解码时,先将一条染色体中表示x的部分和表示y的部分拆分开来,再把二进制转换为10进制。
//解码
public class ClsDecode {
	//单个染色体解码
	public double[] decode(String single,int GENE){
		//二进制数前GENE/2位为x的二进制字符串,后GENE/2位为y的二进制字符串
		int a=Integer.parseInt(single.substring(0,GENE/2),2);
		int b=Integer.parseInt(single.substring(GENE/2,GENE),2);
		double[] x = {-1,-1};//new double[2];
		x[0] = a * (6.0 - 0) / (Math.pow(2, GENE/2) - 1);	//x的基因
		x[1] = b * (6.0 - 0) / (Math.pow(2, GENE/2) - 1);	//y的基因
		
		return x;
	}
}
2 计算适应度
达到的效果是是输入群体数组,然后返回对应的适应度数组
//适应度
import java.lang.Math;
public class ClsFitness {
	//计算个体的适应度
	public double fitSingle(String str,int GENE){
		ClsDecode decode = new ClsDecode();
		double[] x =decode.decode(str,GENE);
		//适应度计算公式
		//问题:如果计算出来有负有正该怎么处理?
		double fitness = Math.sin(2 * x[0]) * Math.sin(2 * x[0]) 
				+ Math.sin(2 * x[1]) * Math.sin(2 * x[1]);	//sin+sin越大,3-sin-sin越小,即得到的值越小个体的适应度就越大
		return fitness;
	}
 
	//批量计算数组的适应度
	public double[] fitAll(String str[],int GENE){
		double [] fit = new double[str.length];
		for(int i = 0;i < str.length; i++){
			fit[i] = fitSingle(str[i],GENE);
		}
		return fit;
	}
	
	//适应度最值(返回序号)
	public int mFitNum(double fit[]){
		double m = fit[0]; 
		int n = 0;
		for(int i = 0;i < fit.length; i++){
			if(fit[i] > m){
				//最大值
				m = fit[i];
				n = i;
			}
		}
		return n;
	}
	
	//适应度最值(返回适应度)
	public double mFitVal(double fit[]){
		double m = fit[0]; 
		for(int i = 0;i < fit.length; i++){
			if(fit[i] > m){
				//最大值
				m = fit[i];
			}
		}
		return m;
	}
}
3 选择
基本思想:适应度越大的个体,被选中保留到下一代的可能性越高。使用轮盘赌算法可以达到这个目的。
为了保持群体总数不变,被淘汰的个体由随机生成的新个体补充

//轮盘赌选择
public class ClsRWS {
	ClsInit init = new ClsInit();
	ClsFitness fitness = new ClsFitness();
	
	/* fit[]适应度数组
	 * group[]群体数组
	 * GENE基因数
	 * */
	public String[] RWS(String group[], int GENE){
		double p[] = new double[group.length];	//染色体概率数组
		String[] newgroup = new String[group.length];
		double F = 0;	//适应度的和
		double[] fit = fitness.fitAll(group,GENE);	//计算适应度数组
		//求适应度的和F
		for(int i = 0; i < fit.length; i++){
			F = F +fit[i]; 			
		}
		
		//初始化p[]
		for(int i = 0; i < fit.length; i++){
			p[i] = fit[i] / F;
		}
		
		//求出适应度最大的个体maxStr,它的序号是max
		int max = fitness.mFitNum(fit);
		String maxStr = group[max];
		
		//转动轮盘,适应度大的被选中的概率大
		for (int i = 0; i < fit.length; i++){
			double r = Math.random();
			double q= 0;	//累计概率
			
			//适应度最大的个体直接继承
			if(i == max){
				newgroup[i] = maxStr;	
				p[i] = 0;	//将其概率置空	
				//System.out.println("继承的最大适应度为"+fit[i]);
				continue;
			}
 
			//遍历轮盘,寻找轮盘指针指在哪个区域
			for(int j = 0; j < fit.length; j++){
				q += p[j];
				if(r <= q){	
					newgroup[i] = group[j];	//如果被选中,保留进下一代
					p[j] = 0;	//将其概率置空					
					break;
				}	
				newgroup[i] = init.initSingle(GENE);	//如果没被选中,被外来者取代
			}
		}
		return newgroup;
	}
}
4 交叉
染色体依次两两配对,随机在一对染色体上选取一点分成两段,然后互换重组为新的两条染色体。(在交叉这一步,有更好的策略是只选取选择得到的适应度高的个体进行交叉,并且可以选择交叉后保留原个体)
    
//交叉
public class ClsCross {
	ClsFitness fitness = new ClsFitness();
	//group群体
	//GENE基因数
	//mFitNum最大适应度染色体序号
	public String[] cross(String[] group,int GENE,double crossRate){
		String temp1, temp2;
		int pos = 0;
		
		double[] fit = fitness.fitAll(group,GENE);
		int mFitNum = fitness.mFitNum(fit);	//计算适应度最大的染色体序号
		String max = group[mFitNum];
		
		for(int i = 0; i < group.length; i++){
			if(Math.random() < crossRate){
				pos = (int)(Math.random()*GENE) + 1;	//交叉点
				temp1 = group[i].substring(0, pos) + group[(i+1) % group.length].substring(pos);	//%用来防止数组越界
				temp2 = group[(i+1) % group.length].substring(0, pos) + group[i].substring(pos);
				group[i] = temp1;
				group[(i+1) % group.length] = temp2;
			}
		}
		group[0] = max;
		return group;
	}
}
5 变异
在染色体上随机选取一位,翻转其二进制位

//变异
public class ClsMutation {
	//替换String中的指定位
	//str要改动的字符串
	//num要改动的位(从0开始数)
	//pos要把这一位改动成什么
	public String replacePos(String str,int num,String pos){
		String temp;
		if(num == 0){
			temp = pos + str.substring(1);
		}else if(num == str.length()-1){
			temp = str.substring(0, str.length() - 1) + pos;
		}else{
			String temp1 = str.substring(0, num);
			String temp2 = str.substring(num + 1);
			temp = temp1 + pos + temp2;
		}
		return temp;		
	}
	
	//MP = Mutation probability变异概率
	public String[] mutation(String[] group,int GENE,double MP){
		ClsFitness fitness = new ClsFitness();
		double[] fit = fitness.fitAll(group,GENE);
		int mFitNum = fitness.mFitNum(fit);	//计算适应度最大的染色体序号
		String max = group[mFitNum];
		
		for(int i = 0; i < group.length * MP; i++){
			int n = (int) (Math.random() * GENE * group.length );	//从[0,GENE * group.length)区间取随机数
			int chrNum = (int) (n / GENE);	//取得的染色体数组下标
			int gNum = (int)(n % GENE); 	//取得的基因下标
			String temp = "";
			
			if(group[chrNum].charAt(gNum) == '0' ){
				temp = replacePos(group[chrNum], gNum, "1");
			}else{
				temp = replacePos(group[chrNum], gNum, "0");
			}
			group[chrNum] = temp;
		}
		group[0] = max;
		return group;
	}
}
6 主方法
我在选择,交叉和变异后都输出了当前群体的最大适应度,以便观察三种操作对于适应度的影响

public class GAmain {
	public static final int groupsize = 10;	//染色体数(群体中个体数)
	public static final double MP = 0.15;	//变异概率
	public static final double CP = 0.6;	//交叉概率
	public static final int ITERA = 1000;	//迭代次数
	public static final int accuracy = 8;	//精确度,选择精确到小数点后几位
	
	
	//求出精度对应的所需基因数
	int temp = (int) ((int)Math.log(6)+ accuracy*Math.log(10) );
	int GENE = temp * 2;	//基因数
	
	
	//输出原群体和适应度,测试用
	public void outAll(String[] group){
		ClsFitness fitness = new ClsFitness();
		System.out.println("原群体");		
		for(String str:group){
			System.out.println(str);
		}	
 
		double fit[] = fitness.fitAll(group,GENE);
		System.out.println("原群体适应度");
		for(double str:fit){
			System.out.println(str);
		}
	}
	
	//输出适应度最大值,以及返回最优的个体,测试用
	public int outMax(String str,String[] group){
		ClsFitness fitness = new ClsFitness();
		double[] fit = fitness.fitAll(group,GENE);
		double max = fitness.mFitVal(fit);
		System.out.println(str+"后适应度最大值为"+max);
		return fitness.mFitNum(fit);
	}
 
	public static void main(String[] args) {
		ClsInit init = new ClsInit();
		ClsFitness fitness = new ClsFitness();
		ClsRWS rws = new ClsRWS();
		ClsCross cross = new ClsCross();
		ClsMutation mutation = new ClsMutation();
		ClsDecode decode = new ClsDecode();
		GAmain ga = new GAmain();
		
		String group[] = init.initAll(ga.GENE,groupsize);	//初始化
		ga.outAll(group);
		
		for(int i = 0; i < ITERA; i++){
			group = rws.RWS(group, ga.GENE); //选择
			ga.outMax("选择",group);
			
			group = cross.cross(group,ga.GENE,CP);	//交叉
			ga.outMax("交叉",group);
			
			group = mutation.mutation(group, ga.GENE, MP);	//变异
			ga.outMax("变异",group);
			
			System.out.println("");
		}
		int max = ga.outMax("完成", group);
		double best[] = decode.decode(group[max], ga.GENE);	//解码
		double result = 3-fitness.fitSingle(group[max], ga.GENE);
		System.out.println("x1 = "+best[0]+"\n"+"x2 = "+best[1]);
		System.out.println("最小值为"+result);
	}
}

3.2 TSP问题

1 环境类
10*10的地图
有15个城市,用二维坐标点表示
种群大小、变异率、城市个数、终止代数、基因长度
 1 /**
 2  * 
 3  */
 4 package geneAlgo;
 5 
 6 import java.util.ArrayList;
 7 import java.util.List;
 8 
 9 /**
10  * @author KONGHE
11  * 
12  */
13 public class ENV {
14     public static int GENE_LENGTH = 15;//如果要添加新的城市,必须要修改这里对应
15 
16     public static int GROUP_SIZE = 50;  //种群大小
17 
18     public static double DISSOCIATION_RATE = 0.01; //变异概率
19 
20     public static double ADAPT_GOAL = 26.6;//适应度目标
21 
22     public static int SEQ_MAX_GROUP = 8;
23 
24     public static int SEQ_GROUP_MAX_LENGTH = 13;
25 
26     public static int SEQ_BREAK_START = 1;
27 
28     public static double CHANGE_FIRST_SEQ = 0.5;
29 
30     public static double IF_HAVE_CHILDREN = 0.8; //产生子代的概率是0.8
31 
32     public static int KEEP_BEST_NUM = 1; //让多少个精英传递到下一代
33 
34     public static int DIS_AFTER_MAX_GENE = 40; //多少代后终止
35 
36     // how much rate to keep bad rate while compare
37     public static double KEEP_BAD_INDIVIDAL = 0.0;//最坏个体的概率
38 
39     public static double KEEP_BAD_INDIVIDAL_MAX = 0.5;//最坏个体的最大概率
40 
41     public static double KEEP_BAD_INDIVIDAL_MIN = 0.0;//最坏个体的最小概率
42 
43     public static int CITIES_LIST[][] = new int[ENV.GENE_LENGTH][2]; //15个城市的坐标
44     static {   //初始化城市的坐标
45         CITIES_LIST[0][0] = 1;
46         CITIES_LIST[0][1] = 1;
47         CITIES_LIST[1][0] = 3;
48         CITIES_LIST[1][1] = 1;
49         CITIES_LIST[2][0] = 2;
50         CITIES_LIST[2][1] = 2;
51         CITIES_LIST[3][0] = 1;
52         CITIES_LIST[3][1] = 4;
53         CITIES_LIST[4][0] = 3;
54         CITIES_LIST[4][1] = 5;
55         CITIES_LIST[5][0] = 5;
56         CITIES_LIST[5][1] = 4;
57         CITIES_LIST[6][0] = 6;
58         CITIES_LIST[6][1] = 2;
59         CITIES_LIST[7][0] = 7;
60         CITIES_LIST[7][1] = 4;
61         CITIES_LIST[8][0] = 8;
62         CITIES_LIST[8][1] = 5;
63         CITIES_LIST[9][0] = 8;
64         CITIES_LIST[9][1] = 7;
65         CITIES_LIST[10][0] = 4;
66         CITIES_LIST[10][1] = 8;
67         CITIES_LIST[11][0] = 6;
68         CITIES_LIST[11][1] = 6;
69         CITIES_LIST[12][0] = 4;
70         CITIES_LIST[12][1] = 2;
71         CITIES_LIST[13][0] = 7;
72         CITIES_LIST[13][1] = 6;
73         CITIES_LIST[14][0] = 2;
74         CITIES_LIST[14][1] = 7;
75     }
76 
77     public static long getRandomInt(long from, long to) {   //产生随机整数
78 
79         return from > to ? from : (long) Math.round(Math.random() * (to - from) + from);
80     }
81 
82     public static boolean doOrNot(double rate) { //是否进行下一步操作
83         return Math.random() <= rate;
84     }
85 
86     public static List getGeneLinkList() { //产生种群
87         List geneList = new ArrayList();
88         for (int i = 1; i <= ENV.GENE_LENGTH - 1; i++) {
89             geneList.add(i);
90         }
91         return geneList;
92     }
93 }
94
2 主程序
 1 package geneAlgo;
 2 
 3 public class Evolution {
 4 
 5     public static void main(String[] args) {
 6         Population originalPopulation = Population.getOriginalPopulation();
 7         
 8         Individal x = Individal.getRandomIndividal();
 9 
10         int i = 0;
11         while (originalPopulation.getBestAdapt().getAdaptability() > ENV.ADAPT_GOAL) {
12 //        while(i<20){
13             originalPopulation.evolute();
14             originalPopulation.printBest();
15 //            originalPopulation.print();
16             i++;
17         }
18         originalPopulation.printBest();
19         // Individal y = Individal.getRandomIndividal();
20         // x.print();
21         // y.print();
22         // // //x.print();
23         // x.makeBabyWith(y);
24 
25         // GeneSeqMap map = new GeneSeqMap();
26         // map.addObjects(3, 4);
27         // map.addObjects(4, 6);
28         // map.addObjects(7, 3);
29         // map.print();
30         // Population x = evolution.getOriginalPopulation();
31         // //System.out.println(x.getAdaptability());
32         // x.print();
33         // System.out.println(ENV.doOrNot(1));
34     }
35 
36 }
37
3 基因工具类
 1 package geneAlgo;
 2 
 3 import java.util.HashMap;
 4 import java.util.Iterator;
 5 import java.util.Map;
 6 import java.util.Set;
 7 
 8 public class GeneSeqMap {
 9     private Map<Integer, Integer> seqMap = new HashMap<Integer, Integer>();
10 
11     public void addObjects(int a, int b) {
12         if (seqMap.containsKey(b) && seqMap.get(b) == a) {
13             seqMap.remove(b);
14             return;
15         }
16         if (seqMap.containsKey(a)) {
17             // in this project, it's not possible
18             int tempValue = seqMap.get(a);
19             seqMap.remove(a);
20             seqMap.put(b, tempValue);
21             return;
22         } else if (seqMap.containsValue(a)) {
23             Set entries = seqMap.entrySet();
24             Iterator iter = entries.iterator();
25             while (iter.hasNext()) {
26                 Map.Entry entry = (Map.Entry) iter.next();
27                 Integer key = (Integer) entry.getKey();
28                 Integer value = (Integer) entry.getValue();
29                 if (value == a) {
30                     seqMap.remove(key);
31                     seqMap.put(key, b);
32                     if(seqMap.containsKey(b)){
33                         int val=seqMap.get(b);
34                         seqMap.remove(b);
35                         seqMap.remove(key);
36                         seqMap.put(key, val);
37                     }
38                     return;
39                 }
40             }
41         }
42         if (seqMap.containsKey(b)) {
43             int value = seqMap.get(b);
44             seqMap.remove(b);
45             seqMap.put(a, value);
46             return;
47         } else if (seqMap.containsValue(b)) {
48             // it's not possible
49             return;
50         }
51         seqMap.put(a, b);
52     }
53 
54     public Integer getValueByKey(Integer key) {
55         if (seqMap.containsKey(key)) {
56             return seqMap.get(key);
57         } else {
58             return null;
59         }
60     }
61 
62     public Integer getKeyByValue(Integer value) {
63         if (seqMap.containsValue(value)) {
64             Set entries = seqMap.entrySet();
65             Iterator iter = entries.iterator();
66             while (iter.hasNext()) {
67                 Map.Entry<Integer, Integer> entry = (Map.Entry<Integer, Integer>) iter.next();
68                 Integer key = entry.getKey();
69                 Integer val = entry.getValue();
70                 if (value == val) {
71                     return key;
72                 }
73             }
74         } 
75         return null;
76         
77     }
78 
79     public void print() {
80         Set<Integer> keys = seqMap.keySet();
81         Iterator iter = keys.iterator();
82         while (iter.hasNext()) {
83             Integer key = (Integer) iter.next();
84             System.out.println(key + " " + seqMap.get(key) + ";");
85         }
86     }
87 }
88
4 个体类
  1 package geneAlgo;
  2 
  3 import java.util.ArrayList;
  4 import java.util.List;
  5 
  6 public class Individal {
  7     private int genes[] = new int[ENV.GENE_LENGTH];
  8 	//获取个体的基因
  9     public int[] getGenes() {
 10         return genes;
 11     }
 12 	//设置基因
 13     public void setGenes(int[] genes) {
 14         this.genes = genes;
 15     }
 16 	//克隆种群
 17     private int[] getClone(int[] source) {
 18         int result[] = new int[source.length];
 19         for (int i = 0; i < source.length; i++) {
 20             result[i] = source[i];
 21         }
 22         return result;
 23     }
 24 	//和partner产生子代
 25     public Individal[] makeBabyWith(Individal partner) {
 26         // parents have only two children
 27         Individal children[] = new Individal[2];
 28         int genes1[] = getClone(this.genes);
 29         int genes2[] = getClone(partner.getGenes());
 30         if (ENV.doOrNot(ENV.IF_HAVE_CHILDREN)) {
 31             GeneSeqMap seqMap = new GeneSeqMap();  // used to correct illegal exchange
 32             List<Integer> seqBreak = new ArrayList<Integer>();
 33             List<Integer> seqChange = new ArrayList<Integer>();
 34             seqBreak.add(ENV.SEQ_BREAK_START);
 35             int i = (int) ENV.getRandomInt(1, ENV.GENE_LENGTH - ENV.SEQ_BREAK_START - 1);
 36             int j = 1;
 37             while (seqBreak.get(seqBreak.size() - 1) + i <= ENV.GENE_LENGTH - 1 && j <= ENV.SEQ_MAX_GROUP) {
 38                 seqBreak.add(seqBreak.get(seqBreak.size() - 1) + i);
 39                 i = (int) ENV.getRandomInt(i + 1, ENV.GENE_LENGTH);
 40             }
 41 
 42             j = 0;
 43             boolean changeFirstSeq = ENV.doOrNot(ENV.CHANGE_FIRST_SEQ);
 44             for (i = 0; i < seqBreak.size(); i++) {
 45                 int nextBreakPos = (i == seqBreak.size() - 1 ? ENV.GENE_LENGTH : seqBreak.get(i + 1));
 46                 for (int m = seqBreak.get(i); m < nextBreakPos; m++) {
 47                     if ((j == 0 && changeFirstSeq) || (j == 1 && !changeFirstSeq)) {
 48                         seqMap.addObjects(genes1[m], genes2[m]);
 49                         int temp = genes1[m];
 50                         genes1[m] = genes2[m];
 51                         genes2[m] = temp;
 52                         seqChange.add(m);
 53                     } else {
 54                         break;
 55                     }
 56                 }
 57                 j = 1 - j;
 58             }
 59 
 60             for (int m = 0; m < ENV.GENE_LENGTH; m++) {
 61                 if (seqChange.contains(m)) {
 62                     continue;
 63                 }
 64                 Integer genes1Change = seqMap.getKeyByValue(genes1[m]);
 65                 Integer genes2Change = seqMap.getValueByKey(genes2[m]);
 66                 if (genes1Change != null) {
 67                     genes1[m] = genes1Change.intValue();
 68                 }
 69                 if (genes2Change != null) {
 70                     genes2[m] = genes2Change.intValue();
 71                 }
 72             }
 73 
 74         }
 75         children[0] = new Individal();
 76         children[1] = new Individal();
 77         children[0].setGenes(genes1);
 78         children[1].setGenes(genes2);
 79         return children;
 80     }
 81 	//变异
 82     public void dissociation() {
 83         // change own gene sequence by dissociation percente.
 84         for (int i = 1; i < genes.length; i++) {
 85             // boolean ifChange = ENV.doOrNot(ENV.DISSOCIATION_RATE * (ENV.GENE_LENGTH - 1));
 86             boolean ifChange = ENV.doOrNot(ENV.DISSOCIATION_RATE);
 87             if (ifChange) {
 88                 // long start = ENV.getRandomInt(1, ENV.GENE_LENGTH - 1);
 89                 long start = i;
 90                 long goStep = ENV.getRandomInt(1, ENV.GENE_LENGTH - 2);
 91                 long changePos = start + goStep;
 92                 if (changePos >= ENV.GENE_LENGTH) {
 93                     changePos = changePos - ENV.GENE_LENGTH + 1;
 94                 }
 95                 int temp = genes[(int) start];
 96                 genes[(int) start] = genes[(int) changePos];
 97                 genes[(int) changePos] = temp;
 98             }
 99 
100         }
101     }
102 	//打印个体基因
103     public void print() {
104         for (int i = 0; i < genes.length; i++) {
105             System.out.print(genes[i] + ";");
106         }
107         System.out.print("--" + this.getAdaptability());
108         System.out.println();
109     }
110 	//获取适应度
111     public double getAdaptability() {
112         int seq[] = this.getGenes();
113         double totalLength = 0;
114         for (int i = 0; i < seq.length - 1; i++) {
115             double length = Math.hypot(ENV.CITIES_LIST[seq[i]][0] - ENV.CITIES_LIST[seq[i + 1]][0],
116                     ENV.CITIES_LIST[seq[i]][1] - ENV.CITIES_LIST[seq[i + 1]][1]);
117             totalLength += length;
118         }
119         return totalLength;
120     }
121 	//产生随机个体
122     public static Individal getRandomIndividal() {
123         Individal individal = new Individal();
124         int[] geneSeq = individal.getGenes();
125         List geneList = ENV.getGeneLinkList();
126         int tempLength = geneList.size();
127         for (int i = 1; i <= tempLength; i++) {
128             long random = ENV.getRandomInt(0, ENV.GENE_LENGTH * 5);
129             int seq = (int) random % geneList.size();
130             geneSeq[i] = (Integer) geneList.get(seq);
131             geneList.remove(seq);
132         }
133         return individal;
134     }
135 
136 }
137
5 群体类
1 package geneAlgo;
  2 
  3 import java.util.ArrayList;
  4 import java.util.List;
  5 
  6 public class Population {
  7 
  8     private long eras = 0;
  9 
 10     private double historyBest = ENV.ADAPT_GOAL + 50;
 11 
 12     private List<Individal> individals = new ArrayList<Individal>();
 13 
 14     public List<Individal> getIndividals() {
 15         return individals;
 16     }
 17 
 18     public void setIndividals(List<Individal> individals) {
 19         this.individals = individals;
 20     }
 21 
 22     public boolean addIndividal(Individal individal) {
 23         boolean result = true;
 24         if (this.individals.size() >= ENV.GROUP_SIZE) {
 25             result = false;
 26         } else {
 27             this.individals.add(individal);
 28         }
 29         return result;
 30     }
 31 
 32     public void print() {
 33         for (int i = 0; i < individals.size(); i++) {
 34             individals.get(i).print();
 35         }
 36     }
 37 
 38     public static Population getOriginalPopulation() {
 39         Population original = new Population();
 40         Individal a = Individal.getRandomIndividal();
 41         Individal b = Individal.getRandomIndividal();
 42         while (original.addIndividal(a.getAdaptability() < b.getAdaptability() ? a : b)) {
 43             a = Individal.getRandomIndividal();
 44             b = Individal.getRandomIndividal();
 45         }
 46         return original;
 47     }
 48 
 49     public void evolute() {
 50         Population evoPool = new Population();
 51         // make evolution pool
 52         while (evoPool.individals.size() < ENV.GROUP_SIZE) {
 53             int indi1 = (int) ENV.getRandomInt(0, ENV.GROUP_SIZE - 1);
 54             int indi2 = (int) ENV.getRandomInt(0, ENV.GROUP_SIZE - 1);
 55             while (indi2 == indi1) {
 56                 indi2 = (int) ENV.getRandomInt(0, ENV.GROUP_SIZE - 1);
 57             }
 58             if (this.individals.get(indi1).getAdaptability() == this.individals.get(indi2).getAdaptability()) {
 59                 if (ENV.KEEP_BAD_INDIVIDAL <= ENV.KEEP_BAD_INDIVIDAL_MAX) {
 60                     ENV.KEEP_BAD_INDIVIDAL += 0.0004;
 61                 }
 62             } else {
 63                 if (ENV.KEEP_BAD_INDIVIDAL >= ENV.KEEP_BAD_INDIVIDAL_MIN)
 64                     ENV.KEEP_BAD_INDIVIDAL -= 0.00005;
 65             }
 66             boolean ifKeepBad = ENV.doOrNot(ENV.KEEP_BAD_INDIVIDAL);
 67             if (ifKeepBad) {
 68                 evoPool.addIndividal(this.individals.get(indi1).getAdaptability() > this.individals.get(indi2)
 69                         .getAdaptability() ? this.individals.get(indi1) : this.individals.get(indi2));
 70             } else {
 71                 evoPool.addIndividal(this.individals.get(indi1).getAdaptability() <= this.individals.get(indi2)
 72                         .getAdaptability() ? this.individals.get(indi1) : this.individals.get(indi2));
 73             }
 74             // if (this.individals.get(indi1).getAdaptability() <= this.individals.get(indi2).getAdaptability()) {
 75             // evoPool.addIndividal(individals.get(indi1));
 76             // } else {
 77             // evoPool.addIndividal(individals.get(indi2));
 78             // }
 79         }
 80         Population newPopulation = new Population();
 81         for (int i = 0; i < ENV.GROUP_SIZE - 1; i = i + 2) {
 82             Individal children[] = evoPool.getIndividals().get(i).makeBabyWith(evoPool.getIndividals().get(i + 1));
 83             children[0].dissociation();
 84             children[1].dissociation();
 85             newPopulation.addIndividal(children[0]);
 86             newPopulation.addIndividal(children[1]);
 87         }
 88         this.setIndividals(newPopulation.getIndividals());
 89         this.eras++;
 90     }
 91 
 92     public long getEras() {
 93         return eras;
 94     }
 95 
 96     public void setEras(long eras) {
 97         this.eras = eras;
 98     }
 99 
100     public void printBest() {
101         Individal x = getBestAdapt();
102         x.print();
103         System.out.print("eras:" + this.getEras());
104         System.out.print(" historyBest:" + this.historyBest);
105         System.out.print(" keep bad rate:" + ENV.KEEP_BAD_INDIVIDAL);
106         System.out.println();
107     }
108 
109     public Individal getBestAdapt() {
110         Individal x = null;
111         for (int i = 0; i < ENV.GROUP_SIZE; i++) {
112             if (x == null) {
113                 x = this.getIndividals().get(i);
114             } else {
115                 if (this.getIndividals().get(i).getAdaptability() < x.getAdaptability()) {
116                     x = this.getIndividals().get(i);
117                 }
118             }
119         }
120         if (x.getAdaptability() < this.historyBest) {
121             this.historyBest = x.getAdaptability();
122         }
123         return x;
124     }
125 }
126

二 蚁群算法

1 蚁群算法简介

image.png

2 蚁群算法应用

image.png

3 蚁群算法基本原理

image.png

4 算法步骤

image.png

4.1算法参数

image.png

image.png

4.2构建路径

image.png
image.png
image.png

4.3 更新信息素

image.png
image.png
image.png

4.4 判断是否达到终止条件

image.png

5 TSP例题Java解决

数据
NAME : att48
COMMENT : 48 capitals of the US (Padberg/Rinaldi)
TYPE : TSP
DIMENSION : 48
EDGE_WEIGHT_TYPE : ATT
NODE_COORD_SECTION
1 6734 1453
2 2233 10
3 5530 1424
4 401 841
5 3082 1644
6 7608 4458
7 7573 3716
8 7265 1268
9 6898 1885
10 1112 2049
11 5468 2606
12 5989 2873
13 4706 2674
14 4612 2035
15 6347 2683
16 6107 669
17 7611 5184
18 7462 3590
19 7732 4723
20 5900 3561
21 4483 3369
22 6101 1110
23 5199 2182
24 1633 2809
25 4307 2322
26 675 1006
27 7555 4819
28 7541 3981
29 3177 756
30 7352 4506
31 7545 2801
32 3245 3305
33 6426 3173
34 4608 1198
35 23 2216
36 7248 3779
37 7762 4595
38 7392 2244
39 3484 2829
40 6271 2135
41 4985 140
42 1916 1569
43 7280 4899
44 7509 3239
45 10 2676
46 6807 2993
47 5185 3258
48 3023 1942
EOF

5.1 蚁群算法类

import java.io.*;
/**
 *蚁群优化算法,用来求解TSP问题
 */
public class ACO {
    ant []ants; //定义蚂蚁群
    int antcount;//蚂蚁的数量
    int [][]distance;//表示城市间距离
    double [][]tao;//信息素矩阵
    int citycount;//城市数量
    int[]besttour;//求解的最佳路径
    int bestlength;//求的最优解的长度
    //filename tsp数据文件
    //antnum 系统用到蚂蚁的数量
    public void init(String filename,int antnum) throws FileNotFoundException, IOException{
        antcount=antnum;
        ants=new ant[antcount];
        //读取数据tsp里的数据包括第I个城市与城市的X,Y坐标
        int[] x;
        int[] y;
        String strbuff;
        BufferedReader tspdata = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
        strbuff = tspdata.readLine();//读取第一行,城市总数(按文件格式读取)
        citycount = Integer.valueOf(strbuff);
        distance = new int[citycount][citycount];
        x = new int[citycount];
        y = new int[citycount];
        for (int citys = 0; citys < citycount; citys++) {
            strbuff = tspdata.readLine();
            String[] strcol = strbuff.split(" ");
            x[citys] = Integer.valueOf(strcol[1]);//读取每排数据的第2二个数字即横坐标
            y[citys] = Integer.valueOf(strcol[2]);
        }
        //计算两个城市之间的距离矩阵,并更新距离矩阵
        for (int city1 = 0; city1 < citycount - 1; city1++) {
            distance[city1][city1] = 0;
            for (int city2 = city1 + 1; city2 < citycount; city2++) {
                distance[city1][city2] = (int) (Math.sqrt((x[city1] - x[city2]) * (x[city1] - x[city2])
                        + (y[city1] - y[city2]) * (y[city1] - y[city2])));
                distance[city2][city1] = distance[city1][city2];//距离矩阵是对称矩阵
            }
        }
        distance[citycount - 1][citycount - 1] = 0;
        //初始化信息素矩阵
        tao=new double[citycount][citycount];
        for(int i=0;i<citycount;i++)
        {
            for(int j=0;j<citycount;j++){
                tao[i][j]=0.1;
            }
        }
        bestlength=Integer.MAX_VALUE;
        besttour=new int[citycount+1];
        //随机放置蚂蚁
        for(int i=0;i<antcount;i++){
            ants[i]=new ant();
            ants[i].RandomSelectCity(citycount);
        }
    }
    //maxgen ACO的最多循环次数
    public void run(int maxgen){
        for(int runtimes=0;runtimes<maxgen;runtimes++){
            //每次迭代,所有蚂蚁都要跟新一遍,走一遍
            //System.out.print("no>>>"+runtimes);
            //每一只蚂蚁移动的过程
            for(int i=0;i<antcount;i++){
                for(int j=1;j<citycount;j++){
                    ants[i].SelectNextCity(j,tao,distance);//每只蚂蚁的城市规划
                }
                //计算蚂蚁获得的路径长度
                ants[i].CalTourLength(distance);
                if(ants[i].tourlength<bestlength){
                    //保留最优路径
                    bestlength=ants[i].tourlength;
                    //runtimes仅代表最大循环次数,但是只有当,有新的最优路径的时候才会显示下列语句。
                    //如果后续没有更优解(收敛),则最后直接输出。
                    System.out.println("第"+runtimes+"代(次迭代),发现新的最优路径长度:"+bestlength);
                    for(int j=0;j<citycount+1;j++)
                        besttour[j]=ants[i].tour[j];//更新路径
                }
            }
            //更新信息素矩阵
            UpdateTao();
            //重新随机设置蚂蚁
            for(int i=0;i<antcount;i++){
                ants[i].RandomSelectCity(citycount);
            }
        }
       }
    /**
     * 更新信息素矩阵
     */
    private void UpdateTao(){
        double rou=0.5;
        //信息素挥发
        for(int i=0;i<citycount;i++)
            for(int j=0;j<citycount;j++)
                tao[i][j]=tao[i][j]*(1-rou);
        //信息素更新
        for(int i=0;i<antcount;i++){
            for(int j=0;j<citycount;j++){
                tao[ants[i].tour[j]][ants[i].tour[j+1]]+=1.0/ants[i].tourlength;
            }
        }
    }
    /* 输出程序运行结果
     */
    public void ReportResult(){
        System.out.println("最优路径长度是"+bestlength);
        System.out.println("蚁群算法最优路径输出:");
        for(int j=0;j<citycount+1;j++)
            System.out.print( besttour[j]+">>");//输出最优路径
    }
}

5.2 蚂蚁类

import java.util.Random;
/*
 蚂蚁类
 */
public class ant {
    /**
     * 蚂蚁获得的路径
     */
    public int[]tour;//参观城市顺序
    //unvisitedcity 取值是0或1,1表示没有访问过,0表示访问过
    int[] unvisitedcity;
    /**
     * 蚂蚁获得的路径长度
     */
    public int tourlength;//某蚂蚁所走路程总长度。
    int citys;//城市个数
/**
 * 随机分配蚂蚁到某个城市中
 * 同时完成蚂蚁包含字段的初始化工作
 * @param citycount 总的城市数量
 */
    public void RandomSelectCity(int citycount){
        citys=citycount;
        unvisitedcity=new int[citycount];
        tour=new int[citycount+1];
        tourlength=0;
        for(int i=0;i<citycount;i++){
            tour[i]=-1;
            unvisitedcity[i]=1;
        }//初始化各个变量

        long r1 = System.currentTimeMillis();//获取当前时间
        Random rnd=new Random(r1);
        int firstcity=rnd.nextInt(citycount);//随机指定第一个城市
        unvisitedcity[firstcity]=0;//0表示访问过
        tour[0]=firstcity;//起始城市
    }
    /**
     * 选择下一个城市
     * @param index 需要选择第index个城市了
     * @param tao   全局的信息素信息
     * @param distance  全局的距离矩阵信息
     */
    public void SelectNextCity(int index,double[][]tao,int[][]distance){
        double []p;
        p=new double[citys];//下一步要走的城市的选中概率
        //计算选中概率所需系数。
        double alpha=1.0;
        double beta=2.0;
        double sum=0;
        int currentcity=tour[index-1];//蚂蚁所处当前城市
        //计算公式中的分母部分(为下一步计算选中概率使用)
        for(int i=0;i<citys;i++){
            if(unvisitedcity[i]==1)//没走过
                sum+=(Math.pow(tao[currentcity][i], alpha)*
                        Math.pow(1.0/distance[currentcity][i], beta));
        }
        //计算每个城市被选中的概率
        for(int i=0;i<citys;i++){
            if(unvisitedcity[i]==0)
                p[i]=0.0;//城市走过了,选中概率就是0
            else{
                //没走过,下一步要走这个城市的概率是?
                p[i]=(Math.pow(tao[currentcity][i], alpha)*
                        Math.pow(1.0/distance[currentcity][i], beta))/sum;
            }
        }
        long r1 = System.currentTimeMillis();
        Random rnd=new Random(r1);
        double selectp=rnd.nextDouble();
        //轮盘赌选择一个城市;
        double sumselect=0;
        int selectcity=-1;
        //城市选择随机,直到n个概率加起来大于随机数,则选择该城市
        for(int i=0;i<citys;i++){//每次都是顺序走。。。。。
            sumselect+=p[i];
            if(sumselect>=selectp){
                selectcity=i;
                break;
            }
        }
        if (selectcity==-1)//这个城市没有走过
            System.out.println();
        tour[index]=selectcity;
        unvisitedcity[selectcity]=0;
    }
    /**
     * 计算蚂蚁获得的路径的长度
     * @param distance  全局的距离矩阵信息
     */
    public void CalTourLength(int [][]distance){
        tourlength=0;
        tour[citys]=tour[0];//第一个城市等于最后一个要到达的城市
        for(int i=0;i<citys;i++){
            tourlength+=distance[tour[i]][tour[i+1]];//从A经过每个城市仅一次,最后回到A的总长度。
        }
    }
}

5.3 主程序

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
//蚁群算法求解旅行商问题,TSP数据来源
//http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/
//数据中包括城市总量,每个城市的横纵坐标
public class Main {
    /**
     * @param args the command line arguments
     */
    public static void main(String[] args) {
        ACO aco;
        aco=new ACO();
        try {
            aco.init("att48.txt", 100);//城市信息文件,蚂蚁数量
            aco.run(1000);//迭代次数
            aco.ReportResult();
        } catch (FileNotFoundException ex) {
            Logger.getLogger(Main.class.getName()).log(Level.SEVERE, null, ex);
        } catch (IOException ex) {
            Logger.getLogger(Main.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
}

三 神经网络

Tensorflow教学:http://c.biancheng.net/tensorflow/

1 神经网络简介

	生物神经网络主要是指人脑的神经网络,它是人工神经网络的技术原型。人脑是人类思维的物质基础,思维的功能定位在大脑皮层,后者含有大约10^11个神经元,每个神经元又通过神经突触与大约103个其它神经元相连,形成一个高度复杂高度灵活的动态网络。作为一门学科,生物神经网络主要研究人脑神经网络的结构、功能及其工作机制,意在探索人脑思维和智能活动的规律。
	人工神经网络是生物神经网络在某种简化意义下的技术复现,作为一门学科,它的主要任务是根据生物神经网络的原理和实际应用的需要建造实用的人工神经网络模型,设计相应的学习算法,模拟人脑的某种智能活动,然后在技术上实现出来用以解决实际问题。因此,生物神经网络主要研究智能的机理;人工神经网络主要研究智能机理的实现,两者相辅相成

1.1智慧

	对于人类智慧奥秘的探索,不同时代、学科背景的人对于智慧的理解及其实现方法有着不同的思想主张。有的主张用显式逻辑体系搭建人工智能系统,即符号主义。有的主张用数学模型模拟大脑组成以实现智慧,即联结主义,这也就是我们本文讨论的方向。
	那大脑为什么能够思考?科学家发现,原因在于人体的神经网络,而神经网络的基本组成就是神经元:

image.png

1、外部刺激通过神经元的神经末梢,转化为电信号,传导到神经元
2、神经元的树突接收电信号,由神经元处理是否达到激活阈值再输出兴奋或者抑制电信号,最后由轴突将信号传递给其它细胞
3、无数神经元构成神经中枢。神经中枢综合各种信号,做出判断
4、人体根据神经中枢的指令,对外部刺激做出反应

1.2神经元

	既然智慧的基础是神经元,而正因为神经元这些特点才使大脑具有强大的 “运算及决策的能力”,科学家以此为原理发明了人工神经元数学模型,并以神经元为基础而组合成人工神经网络模型。(注:下文谈到的神经元都特指人工神经元)

image.png

	如上图就是人工神经元的基本结构。它可以输入一定维数的输入(如:3维的输入,x1,x2, x3),每个输入都相要乘上相应的权重值(如:w0,w1,w2),乘上每一权重值的作用可以视为对每一输入的加权,也就是对每一输入的神经元对它的重视程度是不一样的。
	接下来神经元将乘上权重的每个输入做下求和(也就是加权求和),并加上截距项(截距项b可以视为对神经元阈值的直接调整),最后由激活函数(f)非线性转换为最终输出值。
	激活函数的种类很多,有sigmoid,tanh,sign,relu,softmax等等(下一专题会讨论下激活函数)。激活函数的作用是在神经元上实现一道非线性的运算,以通用万能近似定理——“如果一个前馈神经网络具有线性输出层和至少一层隐藏层,只要给予网络足够数量的神经元,便可以实现以足够高精度来逼近任意一个在 ℝn 的紧子集 (Compact subset) 上的连续函数”所表明,激活函数是深度神经网络学习拟合任意函数的前提

1.3 神经元到逻辑回归

单个神经元且其激活函数为sigmoid时,既是我们熟知的逻辑回归的模型结构。

image.png

	逻辑回归是一种广义线性的分类模型且其模型结构可以视为单层的神经网络,由一层输入层、一层仅带有一个sigmoid激活函数的神经元的输出层组成,而无隐藏层。(注:计算网络层数不计入输入层)
	在逻辑回归模型结构中,我们输入数据特征x,通过输出层神经元激活函数σ(sigmoid函数)将输入的特征经由sigmoid(wx + b)的计算后非线性转换为0~1区间的数值后输出。学习训练过程是通过梯度下降学到合适的参数w的模型 Y=sigmoid(wx + b),使得模型输出值Y与实际值y的误差最小。
	单层的逻辑回归模型,用数学式表示也就是: Y=sigmoid(wx + b)

1.4 逻辑回归到深度神经网络

	基于前面的介绍可以知道,神经网络也就是神经元按层次连接而成的网络,逻辑回归是单层的神经网络,当我们给仅有输入层及输出层的单层神经网络中间接入至少一层的隐藏层,也就是深度神经网络了

image.png

深度神经网络包含了三种网络层:输入层、隐藏层及输出层。
  1 输入层:为数据特征输入层,输入数据特征个数就对应着网络的神经元数。
  2 隐藏层:即网络的中间层,隐藏层层数可以为0或者很多层,其作用接受前一层网络输出作为当前的输入值,并计算输出当前结果到下一层。隐藏层是神经网络性能的关键,通常由含激活函数的神经元组成,以进一步加工出高层次抽象的特征,以增强网络的非线性表达。隐藏网络层数直接影响模型的拟合效果。
  3 输出层:为最终结果输出的网络层。输出层的神经元个数代表了分类标签的个数(注:在做二分类时,如果输出层的激活函数采用sigmoid,输出层的神经元个数为1个;如果采用softmax分类器,输出层神经元个数为2个)


神经网络的学习过程是这样的:
	数据特征(x)从输入层输入,每层的计算结果由前一层传递到下一层(称为前向传播),最终到输出层输出计算结果。每个网络层由一定数量的神经元组成,神经元可以视为一个个的计算单元,对输入进行加权求和wx + b,神经元内还可以包含激活函数,可以对加权求和的结果进一步做非线性的计算,如sigmoid(wx + b) 。可见,其计算结果由神经元包含的权重(即模型参数w)直接控制。
	而模型参数w是根据与实际数据对比的差异,用反向传播算法去学习出来的。(称为反向传播)
	
最终,这里反问深度神经网络模型是什么呢?也一样是个一层套一层的复合函数:
	y = f(f(..f(wx+b)))

image.png

1.5 应用

image.png

分类问题:图像识别、垃圾邮件识别
回归问题:股价预测、房价预测
排序问题:点击率预估、推荐
生成问题:图像生成、图像风格转换、图像文字描述

2 神经网络入门

2.1 机器学习深度学习简介

1 机器学习

image.png

image.png

2 深度学习
深度学习是机器学习的分支

image.png
image.png

3 深度学习分类
卷积神经网络
循环神经网路
自动编码器
稀疏编码
深度信念网络
限制波尔兹曼机
深度学习+强化学习=深度强化学习

2.2 神经元-逻辑回归模型

1 单输出-最小的神经元

image.png

例子:
image.png

2 二分类-逻辑斯啼回归模型
x为-oo~+oo,y输出为0~1,适合用来表示概率问题

image.png

2.3神经元多输出

1 多输出-多个神经元
增加神经元,就从二分类问题变成多输出
W从向量变成矩阵
输出变成两个Y0 Y1

image.png

image.png

2 多分类-逻辑斯啼回归模型
参考二分类的做法,利用归一化的思路,对多回归进行归一化

image.png

image.png

image.png

3 目标函数

单目标

image.png

经过model后得到的1的概率是0.8,得到0的概率是0.2

多目标

image.png

假设有五个神经元,五种结果0 1 2 3 4 
经过模型计算得到[0 1 2 3 4]=[0.1 0.2 0.25 0.6 0.05]
在进行one-hot编码Loss=abs(y1-y1’)=[0 0 0 1 0]-y1'=[0.1 0.2 0.25 0.6 0.25]=1.2

那么得到3的目标函数的结果为1.2
3 平方差损失函数

image.png

4 交叉熵损失函数

image.png

5 神经网络训练

image.png

2.4 梯度下降

1 梯度下降

image.png

image.png

α代表步长,是人为设置的值

3 Java代码实现

3.1 BP实现

	每层都含有一个一维X特征矩阵即为输入数据,一个二维W权值矩阵,一个一维的误差矩阵error,同时该神经网络中还包含了一个一维的目标矩阵target,记录样本的真实类标。
	1 X特征矩阵:第一层隐含层的X矩阵的长度为输入层输入数据的特征个数+1,隐含层的X矩阵的长度则是上一层的节点的个数+1,X[0]=1。
	2 W权值矩阵:第一维的长度设计为节点(即神经元)的个数,第二维的长度设计为上一层节点的个数+1;W[0][0]为该节点的偏置量
	3 error误差矩阵:数组长度设计为该层的节点个数。 
	4 目标矩阵target:输出层的节点个数与其一致。
	5 激活函数:采用sigmoid函数:1/1+e-x
public class Bp {

    private double[] hide1_x; 输入层即第一层隐含层的输入;hide1_x[数据的特征数目+1], hide1_x[0]为1
    private double[][] hide1_w;// 隐含层权值,hide1_w[本层的节点的数目][数据的特征数目+1];hide_w[0][0]为偏置量
    private double[] hide1_errors;// 隐含层的误差,hide1_errors[节点个数]

    private double[] out_x;// 输出层的输入值即第二次层隐含层的输出 out_x[上一层的节点数目+1], out_x[0]为1
    private double[][] out_w;// 输出层的权值 hide1_w[节点的数目][上一层的节点数目+1]//
                                // out_w[0][0]为偏置量
    private double[] out_errors;// 输出层的误差 hide1_errors[节点个数]

    private double[] target;// 目标值,target[输出层的节点个数]

    private double rate;// 学习速率

    public Bp(int input_node, int hide1_node, int out_node, double rate) {
        super();

        // 输入层即第一层隐含层的输入
        hide1_x = new double[input_node + 1];

        // 第一层隐含层
        hide1_w = new double[hide1_node][input_node + 1];
        hide1_errors = new double[hide1_node];

        // 输出层
        out_x = new double[hide1_node + 1];
        out_w = new double[out_node][hide1_node + 1];
        out_errors = new double[out_node];

        target = new double[out_node];

        // 学习速率
        this.rate = rate;
        init_weight();// 1.初始化网络的权值
    }

    /**
     * 初始化权值
     */
    public void init_weight() {

        set_weight(hide1_w);
        set_weight(out_w);
    }

    /**
     * 初始化权值
     *
     * @param w
     */
    private void set_weight(double[][] w) {
        for (int i = 0, len = w.length; i != len; i++)
            for (int j = 0, len2 = w[i].length; j != len2; j++) {
                w[i][j] = 0;
            }
    }

    /**
     * 获取原始数据
     *
     * @param Data
     *            原始数据矩阵
     */
    private void setHide1_x(double[] Data) {
        if (Data.length != hide1_x.length - 1) {
            throw new IllegalArgumentException("数据大小与输出层节点不匹配");
        }
        System.arraycopy(Data, 0, hide1_x, 1, Data.length);
        hide1_x[0] = 1.0;
    }

    /**
     * @param target
     *            the target to set
     */
    private void setTarget(double[] target) {
        this.target = target;
    }

    /**
     * 2.训练数据集
     *
     * @param TrainData
     *            训练数据
     * @param target
     *            目标
     */
    public void train(double[] TrainData, double[] target) {
        // 2.1导入训练数据集和目标值
        setHide1_x(TrainData);
        setTarget(target);

        // 2.2:向前传播得到输出值;
        double[] output = new double[out_w.length + 1];
        forword(hide1_x, output);

        // 2.3、方向传播:
        backpropagation(output);

    }

    /**
     * 反向传播过程
     *
     * @param output
     *            预测结果
     */
    public void backpropagation(double[] output) {

        // 2.3.1、获取输出层的误差;
        get_out_error(output, target, out_errors);
        // 2.3.2、获取隐含层的误差;
        get_hide_error(out_errors, out_w, out_x, hide1_errors);
         2.3.3、更新隐含层的权值;
        update_weight(hide1_errors, hide1_w, hide1_x);
        // * 2.3.4、更新输出层的权值;
        update_weight(out_errors, out_w, out_x);
    }

    /**
     * 预测
     *
     * @param data
     *            预测数据
     * @param output
     *            输出值
     */
    public void predict(double[] data, double[] output) {

        double[] out_y = new double[out_w.length + 1];
        setHide1_x(data);
        forword(hide1_x, out_y);
        System.arraycopy(out_y, 1, output, 0, output.length);

    }


    public void update_weight(double[] err, double[][] w, double[] x) {

        double newweight = 0.0;
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                newweight = rate * err[i] * x[j];
                w[i][j] = w[i][j] + newweight;
            }

        }
    }

    /**
     * 获取输出层的误差
     *
     * @param output
     *            预测输出值
     * @param target
     *            目标值
     * @param out_error
     *            输出层的误差
     */
    public void get_out_error(double[] output, double[] target, double[] out_error) {
        for (int i = 0; i < target.length; i++) {
            out_error[i] = (target[i] - output[i + 1]) * output[i + 1] * (1d - output[i + 1]);
        }

    }

    /**
     * 获取隐含层的误差
     *
     * @param NeLaErr
     *            下一层的误差
     * @param Nextw
     *            下一层的权值
     * @param output 下一层的输入
     * @param error
     *            本层误差数组
     */
    public void get_hide_error(double[] NeLaErr, double[][] Nextw, double[] output, double[] error) {

        for (int k = 0; k < error.length; k++) {
            double sum = 0;
            for (int j = 0; j < Nextw.length; j++) {
                sum += Nextw[j][k + 1] * NeLaErr[j];
            }
            error[k] = sum * output[k + 1] * (1d - output[k + 1]);
        }
    }

    /**
     * 向前传播
     *
     * @param x
     *            输入值
     * @param output
     *            输出值
     */
    public void forword(double[] x, double[] output) {

        // 2.2.1、获取隐含层的输出
        get_net_out(x, hide1_w, out_x);
        // 2.2.2、获取输出层的输出
        get_net_out(out_x, out_w, output);

    }

    /**
     * 获取单个节点的输出
     *
     * @param x
     *            输入矩阵
     * @param w
     *            权值
     * @return 输出值
     */
    private double get_node_put(double[] x, double[] w) {
        double z = 0d;

        for (int i = 0; i < x.length; i++) {
            z += x[i] * w[i];
        }
        // 2.激励函数
        return 1d / (1d + Math.exp(-z));
    }

    /**
     * 获取网络层的输出
     *
     * @param x
     *            输入矩阵
     * @param w
     *            权值矩阵
     * @param net_out
     *            接收网络层的输出数组
     */
    private void get_net_out(double[] x, double[][] w, double[] net_out) {

        net_out[0] = 1d;
        for (int i = 0; i < w.length; i++) {
            net_out[i + 1] = get_node_put(x, w[i]);
        }

    }

}

3.2 测试例子

//训练模型 进行正负数奇偶数的预测

public class Test {

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {


        Bp bp = new Bp(32, 15, 4, 0.05);

        Random random = new Random();

        List<Integer> list = new ArrayList<Integer>();
        for (int i = 0; i != 6000; i++) {
            int value = random.nextInt();
            list.add(value);
        }

        for (int i = 0; i !=25; i++) {
            for (int value : list) {
                double[] real = new double[4];
                if (value >= 0)
                    if ((value & 1) == 1)
                        real[0] = 1;
                    else
                        real[1] = 1;
                else if ((value & 1) == 1)
                    real[2] = 1;
                else
                    real[3] = 1;

                double[] binary = new double[32];
                int index = 31;
                do {
                    binary[index--] = (value & 1);
                    value >>>= 1;
                } while (value != 0);

                bp.train(binary, real);



            }
        }




        System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");

        while (true) {

            byte[] input = new byte[10];
            System.in.read(input);
            Integer value = Integer.parseInt(new String(input).trim());
            int rawVal = value;
            double[] binary = new double[32];
            int index = 31;
            do {
                binary[index--] = (value & 1);
                value >>>= 1;
            } while (value != 0);

            double[] result =new double[4];
             bp.predict(binary,result);


            double max = -Integer.MIN_VALUE;
            int idx = -1;

            for (int i = 0; i != result.length; i++) {
                if (result[i] > max) {
                    max = result[i];
                    idx = i;
                }
            }

            switch (idx) {
            case 0:
                System.out.format("%d是一个正奇数\n", rawVal);
                break;
            case 1:
                System.out.format("%d是一个正偶数\n", rawVal);
                break;
            case 2:
                System.out.format("%d是一个负奇数\n", rawVal);
                break;
            case 3:
                System.out.format("%d是一个负偶数\n", rawVal);
                break;
            }
        }
    }
}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
首先,我们需要了解什么是联邦学习和Q-learning算法。 联邦学习是一种分布式机器学习技术,它允许多个参与者(例如设备或组织)共同训练一个模型,而不需要将他们的数据集集中在一起。每个参与者只需在本地训练模型,然后将更新的模型参数发送给中央服务器进行聚合,生成一个全局模型。这种方式可以保护数据隐私和安全,同时提高模型的泛化能力。 Q-learning算法是一种基于强化学习算法,它可以用于解决各种问题,包括优化问题。该算法通过不断地学习和更新动作值函数,从而找到最优策略。 现在,我们可以将这两个概念结合起来,使用Q-learning算法实现联邦学习优化算法。具体步骤如下: 1. 定义状态和动作 在联邦学习中,我们可以将每个参与者的数据集视为一个状态。动作可以是参与者更新本地模型参数的步长或其他超参数。 2. 定义奖励函数 奖励函数可以衡量全局模型的性能。例如,可以使用全局模型在测试数据集上的准确率作为奖励函数。这将鼓励参与者采取能够提高全局模型性能的动作。 3. 定义Q-table Q-table是一个表格,它记录了在每个状态下采取每个动作的预期回报。我们可以初始化Q-table为零,并在每次参与者更新本地模型参数时更新它。 4. 实现Q-learning算法 在每一轮迭代中,我们可以使用Q-learning算法来更新Q-table。具体来说,我们可以使用以下公式: Q(s, a) = (1 - alpha) * Q(s, a) + alpha * (r + gamma * max(Q(s', a'))) 其中,Q(s, a)是在状态s下采取动作a的预期回报,alpha是学习率,r是当前的奖励,gamma是折扣因子,s'是下一个状态,a'是在下一个状态下采取的最佳动作。 5. 聚合本地模型参数 在每轮迭代结束后,中央服务器会将所有参与者的本地模型参数进行聚合,生成一个全局模型,并将其发送给所有参与者。这样,每个参与者就可以使用全局模型来更新本地模型参数。 通过这种方式,我们可以实现一个联邦学习优化算法,它可以通过Q-learning算法学习和优化各个参与者的动作,从而提高全局模型的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CandyDingDing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值