一层简单人工神经网络的Java实现




2、数据类

import java.util.Arrays;

public class Data {
	double[] vector;
	int dimention;
	int type;
	public double[] getVector() {
		return vector;
	}
	public void setVector(double[] vector) {
		this.vector = vector;
	}
	public int getDimention() {
		return dimention;
	}
	public void setDimention(int dimention) {
		this.dimention = dimention;
	}
	public int getType() {
		return type;
	}
	public void setType(int type) {
		this.type = type;
	}
	public Data(double[] vector, int dimention, int type) {
		super();
		this.vector = vector;
		this.dimention = dimention;
		this.type = type;
	}
	public Data() {
	}
	@Override
	public String toString() {
		return "Data [vector=" + Arrays.toString(vector) + ", dimention=" + dimention + ", type=" + type + "]";
	}
	
}
3、简单人工神经网络

package cn.edu.hbut.chenjie;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.data.xy.DefaultXYDataset;
import org.jfree.ui.RefineryUtilities;


public class ANN2 {
	private double eta;//学习率
	private int n_iter;//权重向量w[]训练次数
	private List<Data> exercise;//训练数据集
	private double w0 = 0;//阈值
	private double x0 = 1;//固定值
	private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3
	private int testSum = 0;//测试数据总数
	private int error = 0;//错误次数
	DefaultXYDataset xydataset = new DefaultXYDataset();
	
	/**
	 * 向图表中增加同类型的数据
	 * @param type 类型
	 * @param a 所有数据的第一个分量
	 * @param b 所有数据的第二个分量
	 */
	public void add(String type,double[] a,double[] b)
	{
		double[][] data = new double[2][a.length];
		for(int i=0;i<a.length;i++)
		{
			data[0][i] = a[i];
			data[1][i] = b[i];
		}
	    xydataset.addSeries(type, data);  
	}
	
	/**
	 * 画图
	 */
	public void draw()
	{
        JFreeChart jfreechart = ChartFactory.createScatterPlot("exercise", "x1", "x2", xydataset);
    	ChartFrame frame = new ChartFrame("训练数据", jfreechart);
		frame.pack();
		RefineryUtilities.centerFrameOnScreen(frame);
		frame.setVisible(true);
	}
	
	public static void main(String[] args)
	{
		ANN2 ann2 = new ANN2(0.001,100);//构造人工神经网络
		
		List<Data> exercise = new ArrayList<Data>();//构造训练集
		
		//人工模拟1000条训练数据 ,分界线为x2=x1+0.5
		for(int i=0;i<1000000;i++)
		{
			Random rd = new Random();
			double x1 = rd.nextDouble();//随机产生一个分量
			double x2 = rd.nextDouble();//随机产生另一个分量
			double[] da = {x1,x2};//产生数据向量
			Data d = new Data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据
			exercise.add(d);//将训练数据加入训练集
		}
		
		int sum1 = 0;//记录类型1的训练记录数
		int sum2 = 0;//记录类型-1的训练记录数
		for(int i = 0; i < exercise.size(); i++)
		{
			if(exercise.get(i).getType()==1)
				sum1++;
			else if(exercise.get(i).getType()==-1)
				sum2++;
		}
		double[] x1 = new double[sum1];
		double[] y1 = new double[sum1];
		double[] x2 = new double[sum2];
		double[] y2 = new double[sum2];
		int index1 = 0;
		int index2 = 0;
		for(int i = 0; i < exercise.size(); i++)
		{
			if(exercise.get(i).getType()==1)
			{
				x1[index1] = exercise.get(i).vector[0];
				y1[index1++] = exercise.get(i).vector[1];
			}
			else if(exercise.get(i).getType()==-1)
			{
				x2[index2] = exercise.get(i).vector[0];
				y2[index2++] = exercise.get(i).vector[1];
			}
		}
		
		ann2.add("1", x1, y1);
		ann2.add("-1", x2, y2);
		ann2.draw();
		
		ann2.input(exercise);//将训练集输入人工神经网络
		
		ann2.fit();//训练
		
		ann2.showWeigths();//显示权重向量
		
		
		//人工生成一千条测试数据
		for(int i=0;i<10000;i++)
		{
			Random rd = new Random();
			double x1_ = rd.nextDouble();
			double x2_ = rd.nextDouble();
			double[] da = {x1_,x2_};
			Data test = new Data(da, 2, x2_ > x1_+0.5 ? 1 : -1);
			ann2.predict(test);//测试
		}
		
		System.out.println("总共测试" + ann2.testSum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testSum * 100 + "%");
	}
	
	/**
	 * 
	 * @param eta 学习率
	 * @param n_iter 权重分量学习次数
	 */
	public ANN2(double eta, int n_iter) {
		this.eta = eta;
		this.n_iter = n_iter;
	}


	/**
	 * 输入训练集到人工神经网络
	 * @param exercise
	 */
	private void input(List<Data> exercise) {
		this.exercise = exercise;//保存训练集
		weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1
		weights[0] = w0;//权重向量第一个分量为w0
		for(int i = 1; i < weights.length; i++)
			weights[i] = 0;//其余分量初始化为0
	}
	
	
	private void fit() {
		for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次
		{
			for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练
			{
				int real_result = exercise.get(j).type;//y
				int calculate_result = CalculateResult(exercise.get(j));//y'
				double delta0 = eta * (real_result - calculate_result);//计算阈值更新
				w0 += delta0;//阈值更新
				weights[0] = w0;//更新w[0]
				for(int k = 0; k < exercise.get(j).getDimention(); k++)//更新权重向量其它分量
				{
					double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];
					//Δw=η*(y-y')*X
					weights[k+1] += delta;
					//w=w+Δw
				}
				
			}
		}
	}

	private int CalculateResult(Data data) {
		double z = w0 * x0;
		for(int i = 0; i < data.dimention; i++)
			z += data.vector[i] * weights[i+1];
		//z=w0x0+w1x1+...+WmXm
		//激活函数
		if(z>=0)
			return 1;
		else
			return -1;
	}

	private void showWeigths()
	{
		for(double w : weights)
			System.out.println(w);
	}

	private void predict(Data data) {
		int type = CalculateResult(data);
		if(type == data.getType())
		{
			//System.out.println("预测正确");
		}
		else
		{
			//System.out.println("预测错误");
			error ++;
		}
		testSum ++;
	}

	
}

运行结果:

-0.22000000000000017
-0.4416843982815453
0.442444202054685
总共测试10000条数据,有17条错误,错误率:0.16999999999999998%




  • 4
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值