使用Adaline神经网络识别印刷体数字

1.创建Adaline神经网络

       /**
	 * @author Ragty
	 * @param  设置Adaline神经网络
	 * @serialData 2018.4.24
	 * @param inputNeuralCount
	 * @param outputNeuralNetwork
	 */
	public void creatNetwork(int inputNeuralCount, int outputNeuralCount){
		
		//设置神经网络类型为Adaline
		this.setNetworkType(NeuralNetworkType.ADALINE);
		
		//建立输入神经元,表刺激
		NeuronProperties inputNeural = new NeuronProperties();
		inputNeural.setProperty("transferFunction", TransferFunctionType.LINEAR);
	
		//建立神经网络的输入层
		Layer inputLayer = LayerFactory.createLayer(inputNeuralCount, inputNeural);
		inputLayer.addNeuron(new BiasNeuron());
		this.addLayer(inputLayer);
		
		//建立输出神经元
		NeuronProperties outputNeural = new NeuronProperties();
		outputNeural.setProperty("transferFunction", TransferFunctionType.LINEAR);
		
		//创建输出层
		Layer outputLayer = LayerFactory.createLayer(outputNeuralCount, outputNeural);
		this.addLayer(outputLayer);
		
		//输入输出层全连接
		ConnectionFactory.fullConnect(inputLayer, outputLayer);
		NeuralNetworkFactory.setDefaultIO(this);
		
		//设置LMS算法
		//学习步长系数为0.05(由最速下降法引入,表示学习的速度)一般是在0.1或0.01这样的数量级
                //步长太大,不精准,步长太小,学习速度慢,易陷入局部最优
		//设置最大可接受误差为0.5  (LMS中不同于感知机,误差是连续的)
		//(w_new = w_old + 2aep) (b_new = b_old + 2ae)  >>LMS公式(步长系数a,省略常数2)
		LMS lms = new LMS();
		lms.setLearningRate(0.05);
		lms.setMaxError(0.5);
		this.setLearningRule(lms);
		
	}
	
	


2.实现Adaline的核心算法LMS
public LMS(){
		
	}
	
   /**
    *@author Ragty
    *@param  LMS核心算法
    *@serialData 2018.4.24
    *@核心公式  deltaWeight = learningRate * neuronError * input(learingRate是学习系数)
    */
	@Override
	protected void updateNetworkWeights(double[] outputError) {
		// TODO Auto-generated method stub
		int i = 0;
		
		//遍历每个神经元,修改权值
		for(Neuron neuron : neuralNetwork.getOutputNeurons()){
			neuron.setError(outputError[i]);
			this.updateNetworkWeights(neuron);
			i++;
		}
		
	}

	
	/**
	 * @author Ragty
	 * @param  迭代更新每个输入神经元的权值
	 * @serialData 2018.4.24
	 * @param neuron
	 */
	protected void updateNetworkWeights(Neuron neuron) {
		// TODO Auto-generated method stub
		//取得神经元误差
		double neuronError = neuron.getError();
		
		//根据所有神经元输入迭代学习
		for(Connection connection : neuron.getInputConnections()){
			//神经元的一个输入
			double input = connection.getInput();
			double weightChange = this.learningRate * neuronError * input;
			
			//更新权值
			Weight weight = connection.getWeight();
			weight.weightChange = weightChange;
			weight.value += weightChange;
		}
		
	}


3.实现Adaline感知机识别印刷体数字
public class AdalineDemo implements LearningEventListener{

	//设置输入神经元的个数为5*7=35个
	public final static int char_width = 5;
	public final static int char_height = 7;
	
	public static String[][] DIGITS = { 
	      { " OOO ",
	        "O   O",
	        "O   O",
	        "O   O",
	        "O   O",
	        "O   O",
	        " OOO "  },

	      { "  O  ",
	        " OO  ",
	        "O O  ",
	        "  O  ",
	        "  O  ",
	        "  O  ",
	        "  O  "  },

	      { " OOO ",
	        "O   O",
	        "    O",
	        "   O ",
	        "  O  ",
	        " O   ",
	        "OOOOO"  },

	      { " OOO ",
	        "O   O",
	        "    O",
	        " OOO ",
	        "    O",
	        "O   O",
	        " OOO "  },

	      { "   O ",
	        "  OO ",
	        " O O ",
	        "O  O ",
	        "OOOOO",
	        "   O ",
	        "   O "  },

	      { "OOOOO",
	        "O    ",
	        "O    ",
	        "OOOO ",
	        "    O",
	        "O   O",
	        " OOO "  },

	      { " OOO ",
	        "O   O",
	        "O    ",
	        "OOOO ",
	        "O   O",
	        "O   O",
	        " OOO "  },

	      { "OOOOO",
	        "    O",
	        "    O",
	        "   O ",
	        "  O  ",
	        " O   ",
	        "O    "  },

	      { " OOO ",
	        "O   O",
	        "O   O",
	        " OOO ",
	        "O   O",
	        "O   O",
	        " OOO "  },

	      { " OOO ",
	        "O   O",
	        "O   O",
	        " OOOO",
	        "    O",
	        "O   O",
	        " OOO "  } };
	
	
	public static void main(String[] args) {
		
		//设置Adaline神经网络输入节点为35个,输出节点为10个
		Adaline ada = new Adaline(char_width * char_height, DIGITS.length);
		
		//设置训练集为35个输入节点,10个输出节点
		DataSet ds = new DataSet(char_width * char_height, DIGITS.length);
		
		//设置训练集(前面是输入值,后面是期望值)
		for(int i = 0; i < DIGITS.length; i++ ){
           ds.addRow(creatTrainRow(DIGITS[i], i));
		}
		//监督训练过程
		ada.getLearningRule().addListener(new AdalineDemo());
		//训练该神经网络
		ada.learn(ds);
		
		//测试训练好的数据
		for(int i = 0; i < DIGITS.length; i++){
			ada.setInput(image2data(DIGITS[i]));
			ada.calculate();
			print(DIGITS[i]);
			System.out.print(maxIndex(ada.getOutput()));
			System.out.println();
		}
		
	}
	
	
	/**
	 * @author Ragty
	 * @param  设置这几个数字的训练集
	 * @serialData 2018.4.24
	 * @param image
	 * @param idealValue
	 * @return
	 */
	public static DataSetRow creatTrainRow(String[] image, int idealValue){
		double[] output = new double[DIGITS.length];
		
		//将训练集初始化
		for(int i = 0; i <DIGITS.length; i++)
			output[i] = -1;
		
		//输入数据
		double[] input = image2data(image);
		
		//用这样的方式来表示一个具体的数字(10个数字分为十个维度,表示哪个数字把哪个数字的维度设置为1)
		output[idealValue] = 1;
		//设置训练集以及期望值
		DataSetRow dsr = new DataSetRow(input, output);
		return dsr;
	}
	
	
	/**
	 * @author Ragty
	 * @param  将输入的二维数组转化为网络能够识别的格式(有字的地方全部转化为1,无字的地方转化为-1)
	 * @serialData 2018.4.24
	 * @param image
	 * @return
	 */
	public static double[] image2data(String[] image){
		double[] input = new double[char_width * char_height];
		
		for(int row = 0; row < char_height; row++){
			for(int col = 0; col < char_width; col++){
				int index = (row*char_width)+col;
				char ch = image[row].charAt(col);
				input[index] = ch == 'O'? 1 :-1;
			}
		}
		
		return input;
	}
	
	
	/**
	 * @author Ragty
	 * @param  识别输出数据为数字(采用竞争规则,在所有维度里,将最大的那个维度视为1,其余均为0)
	 * @param  即找到数组中最大值的索引下标(第一次从左边的条件进入)
	 * @serialData 2018.4.24
	 * @param data
	 * @return
	 */
	public static int maxIndex(double[] data){
		int result = -1;
		for(int i = 0; i < data.length; i++){
			if(result == -1 || data[i] > data[result]){
				result = i;
			}
		}
		return result;
	}
	
	
	/**
	 * @author Ragty
	 * @param  打印输出的打印字体
	 * @serialData 2018.4.24
	 * @param dIGITS2
	 */
	public static void print(String[] dIGITS2){
		
		for(int i = 0; i <dIGITS2.length; i++){
			if(i == dIGITS2.length-1){
				System.out.print(dIGITS2[i]+"===>");
			} else {
				System.out.println(dIGITS2[i]);
			}
		}
		
	}
	
	/**
	 * @param 监督训练
	 */
	@Override
	public void handleLearningEvent(LearningEvent event) {
		// TODO Auto-generated method stub
		IterativeLearning bp = (IterativeLearning)event.getSource();
        System.out.println("iterate:"+bp.getCurrentIteration()); 
        System.out.println(Arrays.toString(bp.getNeuralNetwork().getWeights()));
	}

	
}


4.识别结果


  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值