人工智能之感知器网络

人工智能之感知器网络

感知器神经元模型


单感知器模型



上面的介绍是感知器神经元的原理,通过一个样本的输入P(p1,p2,,,,pn),输入的每个值与感知器之间都存在一个加权值W(w1,w2,,,,,wn),每个输入值对

应一个加权值,即是加权矩阵的列数,根据输出要求,可以确定加权矩阵中的行数,此外,为了使神经元的模拟更加符合实际,要加入一个偏置数B(b1,b2,,,,bm),

偏置矩阵的行数即是输出的个数,用来对输出进行矫正。

一个输出为:

Y[i] = WP + B[i];

为感知器需要通过反复的学习,反复的矫正加权矩阵W和偏置矩阵B来使感知器真正的具有其应有的作用。实际上,感知器就是为了训练得到符合要求的W和B矩阵。

矫正:

W[i] = W[i] + α*(Y[i] –T[i])*P[i];

B[i] = B[i] + α*(Y[i] –T[i]);

其中的α值为0~1的小数,我的α值取的是0.2。

T为目标值。

为了确定训练效果,这里需要定义一个误差率,误差这样来定义:

E = 1/2∑(Y - T)2

为了保证训练无限循环下去,我们需要规定一个循环次数最大值MAX,如果误差达到规定要求,直接跳出循环,最后的W和P即是我们需要的结果,可以用W和P来进行测试。

下面是主流程类Feeling:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Scanner;
/**
 * 
 * @author 41571
 *
 */
public class Feeling {
	private ArrayList<Input> inputList = new ArrayList<Input>();
	private double[][] weight;	//加权矩阵
	private double[] bais;	//偏置
	private double alpha = 0.2;	//学习率
	private int MAX = 200;
	private double sum;
	private double LIMIT = 0.001;
	public static void main(String[] args){
		new Feeling();
	}
	public Feeling(){
		System.out.println("训练样本有几组?");
		int sampleNum = Integer.parseInt(read());	//样本数
		System.out.println("每组样本有几个数据?");
		int dataNum = Integer.parseInt(read());	//每个样本数据数
		System.out.println("目标结果有几个?");
		int aimNum = Integer.parseInt(read()); //目标结果数目
		weight = new double[aimNum][dataNum];
		bais = new double[aimNum];
		for(int i = 0;i < weight.length;i ++){
			bais[i] = Math.random()*2-1;
			for(int j = 0;j <weight[i].length;j ++){
				weight[i][j] = Math.random()*2-1;
			}
		}
		for(int i = 0;i < sampleNum;i ++){		//实例化样本
			inputList.add(new Input(dataNum,aimNum));
		}
		readLine();
		for(int i = 0;i < inputList.size();i ++){
			inputList.get(i).show();
		}
		doing();
	}
	private void doing() {
		// TODO Auto-generated method stub
		int time=0;int a = 1;
		while(time <= MAX){
			for(int i = 0;i < inputList.size();i ++){
				time++;
				if(time>MAX) {a = 0;break;}
				inputList.get(i).setMiss(weight, bais);
				sum = inputList.get(i).getTotal();
				//System.out.println("误差平方和为:"+sum);
				if(sum <= LIMIT){
					System.out.println("-----------------------------------------");
					System.out.println("经过"+time+"次学习,误差率为:"+sum);
					a = 0;
					System.out.println("-----------------------------------------");
					System.out.println("加权矩阵为:");
					for(int m=0;m < weight.length;m ++){
						for(int j=0;j < weight[m].length;j ++){
							System.out.print(weight[m][j]+" ");
						}
						System.out.println("");
					}
					System.out.println("偏置矩阵为:");
					for(int m=0;m < bais.length;m++){
						System.out.print(bais[m]+" ");
					}
					System.out.println("");
					break;
				}
				inputList.get(i).setWeight(weight, alpha);
				inputList.get(i).setBais(bais, alpha);
			}
			if(a == 0){
				break;
			}
		}
		if(time>MAX)
		System.out.println("尴尬!!!超出次数了。");
	}
	public void readLine(){
		String filePath = "txt.txt";
		int n = 0;
		ArrayList<Double> put = new ArrayList<Double>();
		try{		//读取文件
			File file = new File(filePath);
			if(file.isFile()&&file.exists()){
				InputStreamReader read = new InputStreamReader(new FileInputStream(file),"gbk");
				BufferedReader reader = new BufferedReader(read);
				String line;
				while((line = reader.readLine())!=null){
					String str = "";
					String inp = line+' ';
					char[] inputs = inp.toCharArray();
					for(int i = 0;i < inputs.length;i ++){
						if(inputs[i]!=' '){
							str += inputs[i];
						}else{
							put.add(Double.parseDouble(str));
							str = "";
						}
					}
					double[] puts = new double[put.size()];
					for(int i = 0;i < put.size();i ++){
						puts[i] = put.get(i);
					}
					if(n<inputList.size())
						inputList.get(n%4).setData(puts);
					else
						inputList.get(n%4).setAim(puts);
					n++;
					puts = null;
					put.clear();
				}
				read.close();
			}
		}catch(Exception e){
			e.printStackTrace();
		}
	}
	public String read(){
		Scanner in = new Scanner(System.in);
		return in.nextLine();
	}
}

还有一个存储和计算的Input类:

/**
 * 
 * @author 41571
 *
 */
public class Input {
	private double[] data;
	private double[] aim;
	private double[] error;//每个目标的误差
	private double total;//误差率平方和
	
	public double getTotal() {
		return total;
	}
	public Input(int a,int b){
		data = new double[a];
		aim = new double[b];
		error = new double[b];
	}
	public double[] getData() {
		return data;
	}
	public void setData(double[] data) {
		this.data = data;
	}
	public double[] getAim() {
		return aim;
	}
	public void setAim(double[] aim) {
		this.aim = aim;
	}
	public void setMiss(double[][] weight,double[] bais){
		double[] temp = new double[aim.length];
		total=0;
		for(int i = 0;i < weight.length;i ++){
			for(int j = 0;j < weight[i].length;j ++){
				temp[i] += weight[i][j]*data[j];
			}
			temp[i]+=bais[i];
			error[i]=aim[i]-temp[i];
			total+=error[i]*error[i];
			total = (double)1/2*total;
		}
	}
	public void setWeight(double[][] weight,double alpha){
		for(int i = 0;i < weight.length;i ++){
			for(int j = 0;j < weight[i].length;j ++){
				weight[i][j] = weight[i][j] + alpha*error[i]*data[j];
			}
		}
	}
	public void setBais(double[] bais,double alpha){
		for(int i = 0;i < bais.length;i ++){
			bais[i] = bais[i] + alpha*error[i];
		}
	}
	public void show(){
		System.out.print("数据:");
		for(int i = 0;i < data.length;i ++){
			System.out.print(data[i]+"  ");
		}
		System.out.print("\n目标:");
		for(int i = 0;i < aim.length;i ++){
			System.out.print(aim[i]+"  ");
		}
		System.out.println("");
	}
}


  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值