Java神经网络实现

4 篇文章 0 订阅
2 篇文章 0 订阅

Java神经网络实现

Weight.java

package neural;

import java.io.File;
import java.io.RandomAccessFile;
import java.util.Random;

/**
 * A fully-connected neural-network layer that is also a node in a linked chain of layers.
 *
 * <p>{@code X} is the previous (input-side) layer and {@code next} the following
 * (output-side) layer. The head of the chain is an input layer built with
 * {@link #Weight(float[])} or {@link #Weight(int, boolean)}; it has no weight matrix
 * ({@code W == null}). Further layers are appended with {@link #add(int, boolean)}.
 * {@link #Train(Weight, float[])} and {@link #Test(Weight)} walk the chain starting at that head.
 *
 * <p>All activations and scratch buffers are plain mutable fields that every
 * forward/backward pass overwrites; there is no synchronization.
 */
public class Weight {
	/** Shared random source for {@link #randperm(int, int)}, seeded from the wall clock. */
	static Random random=new Random(System.currentTimeMillis());
	/**
	 * {@code alpha}: learning rate, between 0 and 1; larger values train faster, smaller
	 * values train more precisely. {@code beta}: factor applied to the previous update in the
	 * momentum rule of {@link #Delta_Dw_Weight()}.
	 */
	public static final float alpha=0.01f,beta=0.9f;

	/**
	 * Sigmoid approximation built only from multiplications and divisions, so the same code
	 * can be ported to an OpenCL GPU kernel. e^|x| is approximated by (1 + |x|/256)^256,
	 * computed with eight successive squarings.
	 */
	public static float Extremely_fast_Sigmoid(float x)
	{
		float t=x<0?-x:x;
		t=1.0f+t/256f;
		for(int i=0;i<8;i++)
			t*=t;
		float y=x<0?t:1.0f/t; // approximation of e^(-x)
		return 1.0f/(1.0f+y);
	}

	/**
	 * Approximates e^x with the same (1 + |x|/256)^256 scheme as
	 * {@link #Extremely_fast_Sigmoid(float)}. The result overflows to Infinity for x above
	 * roughly 106, so callers should keep the argument small (see {@link #calculate_Y_Softmax()}).
	 */
	public static float Exp(float x)
	{
		float t=x<0?-x:x;
		t=1.0f+t/256f;
		for(int i=0;i<8;i++)
			t*=t;
		return x<0?1.0f/t:t;
	}

	/** Exact sigmoid using {@link Math#exp(double)} (CPU version). */
	public static float Sigmoid(float x)
	{
		return 1.0f/(1.0f+(float)Math.exp(-x));
	}

	/**
	 * Derivative of the sigmoid expressed through its output: for y = sigmoid(v),
	 * dy/dv = y * (1 - y).
	 */
	public static float dSigmoid(float y)
	{
		return y*(1-y);
	}

	/**
	 * Supplies initial weights in {@link #ReStartWeight()} via {@code getrandnum()}; while it is
	 * null, newly created layers keep the all-zero weights Java allocates.
	 */
	public static Callback callback;
	/**
	 * V: per-unit scratch buffer. After a forward pass it holds the pre-activations (for
	 * softmax: the exponentials); after {@link #Dropout(float)} it holds the dropout mask;
	 * during back-propagation it holds the error propagated back from {@code next}.
	 * Y: this layer's output. Delta: this layer's error term used by the weight update.
	 */
	public float[] V,Y,Delta;
	/**
	 * W[i][j]: weight from unit i of {@code X} to unit j of this layer.
	 * DW: previous momentum update, allocated only when {@code usedw} is true.
	 */
	public float[][] W,DW;
	/** X: previous (input-side) layer; next: following (output-side) layer. */
	public Weight X,next;
	/** Number of units in this layer. */
	int SelfLen;
	/** Per-layer switches: momentum, cross-entropy output error, softmax, ReLU, dropout. */
	public boolean usedw,use_cross_entropy,use_Softmax,use_RelU,use_Dropout;
	/** Per-unit "ReLU was active" mask computed during back-propagation. */
	public boolean[] relu_v;
	/** Mask from the latest {@link #Dropout(float)} call: 0 for dropped units, 1/(1-ratio) for kept ones. */
	private float[] dropoutMask;
	/** True from a {@link #Dropout(float)} call until the next forward pass, i.e. while Y is masked. */
	private boolean dropoutApplied;

	/** Switches this layer to softmax output with a cross-entropy error term. */
	public void useSoftmax()
	{
		System.err.println("Softmax被启用\n交叉熵被启用");
		use_cross_entropy=true;
		use_Softmax=true;
	}

	/** Enables random dropout of this layer's outputs during {@link #Train(Weight, float[])}. */
	public Weight useDropout()
	{
		System.err.println("随机节点丢弃被启用");
		use_Dropout=true;
		return this;
	}

	/** Switches this layer to the ReLU activation. */
	public Weight useRelU()
	{
		System.err.println("RelU被启用");
		relu_v=new boolean[this.SelfLen];
		use_RelU=true;
		return this;
	}

	/**
	 * Returns {@code count} distinct integers drawn uniformly at random from [0, max).
	 *
	 * <p>Uses a partial Fisher-Yates shuffle: it always terminates and can never yield a
	 * negative index, unlike {@code abs(nextInt()) % max} rejection sampling, which breaks on
	 * {@code Integer.MIN_VALUE} and never terminates when {@code count > max}.
	 *
	 * @param max   exclusive upper bound of the values
	 * @param count number of distinct values to draw
	 * @return array of {@code count} distinct values in [0, max)
	 * @throws IllegalArgumentException if {@code count < 0} or {@code count > max}
	 */
	public static int[] randperm(int max, int count)
	{
		if(count<0||count>max)
			throw new IllegalArgumentException("count must be in [0, max]: count="+count+", max="+max);
		int[] pool=new int[max];
		for(int i=0;i<max;i++)
			pool[i]=i;
		int[] picked=new int[count];
		for(int i=0;i<count;i++)
		{
			int j=i+random.nextInt(max-i);
			int swap=pool[i];
			pool[i]=pool[j];
			pool[j]=swap;
			picked[i]=pool[i];
		}
		return picked;
	}

	/**
	 * Applies inverted dropout to Y: keeps round(SelfLen * (1 - ratio)) randomly chosen units,
	 * scaled by 1/(1 - ratio), and zeroes the rest. The mask is also written to V and remembered
	 * so that back-propagation can account for it.
	 *
	 * @param ratio fraction of units to drop, in [0, 1)
	 * @throws IllegalArgumentException if {@code ratio} is outside [0, 1) or NaN
	 */
	public void Dropout(float ratio)
	{
		if(!(ratio>=0f&&ratio<1f))
			throw new IllegalArgumentException("dropout ratio must be in [0, 1): "+ratio);
		if(this.dropoutMask==null||this.dropoutMask.length!=this.SelfLen)
			this.dropoutMask=new float[this.SelfLen];
		float keepScale=1/(1-ratio);
		// min() guards against float rounding of very large SelfLen values.
		int keep=Math.min(Math.round(this.SelfLen*(1-ratio)),this.SelfLen);
		for(int i=0;i<this.SelfLen;i++)
			this.dropoutMask[i]=0;
		for(int idx:Weight.randperm(this.SelfLen,keep))
			this.dropoutMask[idx]=keepScale;
		for(int i=0;i<this.SelfLen;i++)
		{
			this.V[i]=this.dropoutMask[i];
			this.Y[i]*=this.dropoutMask[i];
		}
		this.dropoutApplied=true;
	}

	/**
	 * Creates a layer of {@code SelfLen} units fed by {@code X}, with an all-zero
	 * [X.SelfLen][SelfLen] weight matrix.
	 *
	 * @param usedw when true, also allocates the momentum buffer DW (and logs a warning)
	 */
	public Weight(Weight X,int SelfLen,boolean usedw) 
	{
		this.X=X;
		this.SelfLen=SelfLen;
		Y=new float[SelfLen];
		V=new float[SelfLen];
		Delta=new float[SelfLen];
		W=new float[this.X.SelfLen][SelfLen];
		this.usedw=usedw;
		if(usedw)
		{
			DW=new float[this.X.SelfLen][SelfLen];
			System.err.println("Momentum is enabled, may consume excessive resources. Size="+X.SelfLen+'x'+this.SelfLen);
		}
	}

	/** Creates an input layer whose output is the given array (referenced, not copied). */
	public Weight(float[] Y) {this.Y=Y;this.SelfLen=Y.length;this.X=null;}

	/**
	 * Creates an input layer of {@code SelfLen} units. Y is allocated only when
	 * {@code __new__} is true; otherwise supply it via {@link #quote_Y(float[])} or
	 * {@link #copy_Y(float[])}.
	 */
	public Weight(int SelfLen,boolean __new__)
	{
		if(__new__)
			this.Y=new float[SelfLen];
		this.SelfLen=SelfLen;
		this.X=null;
	}

	/** Makes Y refer to the given array (no copy). */
	public void quote_Y(float[] Y) {this.Y=Y;}

	/** Copies the first SelfLen values of the given array into Y, allocating Y if needed. */
	public void copy_Y(float[] Y) {
		if(this.Y==null)this.Y=new float[this.SelfLen];
		for(int i=0;i<this.SelfLen;i++)
			this.Y[i]=Y[i];
	}

	/** Re-initializes every weight from {@link #callback}; does nothing when no callback is set. */
	public void ReStartWeight()
	{
		if(callback==null)
			return;
		// Use each row's own length: W[0] does not exist when the previous layer has 0 units.
		for(int i=0;i<W.length;i++)
			for(int j=0;j<W[i].length;j++)
				W[i][j]=Weight.callback.getrandnum();
	}

	/**
	 * Appends a new layer of {@code SelfLen} units after this one, initializes its weights,
	 * and returns it so calls can be chained.
	 */
	public Weight add(int SelfLen,boolean usedw)
	{
		Weight tmp=new Weight(this,SelfLen,usedw);
		this.next=tmp;
		tmp.ReStartWeight();
		return tmp;
	}

	/** Links an existing layer after this one (only sets {@code next}). */
	public void add(Weight weight)
	{
		this.next=weight;
	}

	/** Sets the previous (input-side) layer. */
	public void setX(Weight weight)
	{
		this.X=weight;
	}

	/**
	 * Computes the pre-activations V[j] = sum_i W[i][j] * X.Y[i]. It also marks any earlier
	 * dropout mask as no longer applying, since a new forward pass has started.
	 */
	private void weightedSum()
	{
		this.dropoutApplied=false;
		for(int j=0;j<this.SelfLen;j++)
			this.V[j]=0;
		for(int i=0;i<this.W.length;i++)
		{
			float input=this.X.Y[i];
			float[] row=this.W[i];
			for(int j=0;j<this.SelfLen;j++)
				this.V[j]+=row[j]*input;
		}
	}

	/** Forward pass with the fast sigmoid activation. */
	public void calculate_Y_Sigmoid()
	{
		weightedSum();
		for(int i=0;i<this.SelfLen;i++)
			this.Y[i]=Extremely_fast_Sigmoid(V[i]);
	}

	/** Forward pass with the ReLU activation. */
	public void calculate_Y_RelU()
	{
		weightedSum();
		for(int i=0;i<this.SelfLen;i++)
			this.Y[i]=this.V[i]>0?this.V[i]:0;
	}

	/**
	 * Forward pass with softmax. The logits are shifted by their maximum before exponentiation,
	 * so every {@link #Exp(float)} argument is {@code <= 0} and cannot overflow. Unshifted logits
	 * above ~106 produced Infinity and then NaN outputs. After the call, V holds the shifted
	 * exponentials.
	 */
	public void calculate_Y_Softmax()
	{
		weightedSum();
		float max=Float.NEGATIVE_INFINITY;
		for(int i=0;i<this.SelfLen;i++)
			if(this.V[i]>max)
				max=this.V[i];
		float sum=0;
		for(int i=0;i<this.SelfLen;i++)
		{
			float e=Exp(this.V[i]-max);
			this.V[i]=e;
			sum+=e;
		}
		for(int i=0;i<this.SelfLen;i++)
			this.Y[i]=this.V[i]/sum;
	}

	/**
	 * Output-layer error for a sigmoid output with squared error:
	 * Delta = Y(1 - Y)(D - Y).
	 * NOTE(review): dropout applied to the output layer is not accounted for here.
	 *
	 * @param D target values, one per unit
	 * @throws Exception if {@code D.length != SelfLen}
	 */
	public void calculate_Delta_Sigmoid(float D[])throws Exception
	{
		if(this.SelfLen!=D.length)
			throw new Exception("inconsistent length");
		for(int i=0;i<this.SelfLen;i++)
			this.Delta[i]=dSigmoid(this.Y[i])*(D[i]-this.Y[i]);
	}

	/**
	 * Output-layer error for softmax with cross-entropy: Delta = D - Y.
	 * NOTE(review): dropout applied to the output layer is not accounted for here.
	 *
	 * @param D target values, one per unit
	 * @throws Exception if {@code D.length != SelfLen}
	 */
	public void calculate_Delta_cross_entropy(float D[])throws Exception
	{
		if(this.SelfLen!=D.length)
			throw new Exception("inconsistent length");
		for(int i=0;i<this.SelfLen;i++)
			this.Delta[i]=(D[i]-this.Y[i]);
	}

	/** Stores in V[i] the error propagated back from {@code next}: sum_j next.W[i][j] * next.Delta[j]. */
	private void backpropagateFromNext()
	{
		for(int i=0;i<this.SelfLen;i++)
		{
			float sum=0;
			float[] row=this.next.W[i];
			for(int j=0;j<this.next.SelfLen;j++)
				sum+=row[j]*this.next.Delta[j];
			this.V[i]=sum;
		}
	}

	/**
	 * Hidden-layer error for a sigmoid layer. Under dropout the output is Y = sigmoid(v) * mask,
	 * so the derivative is taken at the unscaled output Y / mask and multiplied by the mask.
	 * Dropped units (mask 0) get no error.
	 */
	public void calculate_Delta_Sigmoid()
	{
		backpropagateFromNext();
		for(int i=0;i<this.SelfLen;i++)
		{
			if(!this.dropoutApplied)
			{
				this.Delta[i]=dSigmoid(this.Y[i])*this.V[i];
				continue;
			}
			float scale=this.dropoutMask[i];
			this.Delta[i]=scale==0?0:dSigmoid(this.Y[i]/scale)*scale*this.V[i];
		}
	}

	/**
	 * Hidden-layer error for a ReLU layer. A unit passes error only when Y > 0, which holds
	 * exactly when its pre-activation was positive and, under dropout, it was kept. V cannot be
	 * used for this test because {@link #Dropout(float)} overwrites it with the mask. Under
	 * dropout the error is also scaled by the mask.
	 */
	public void calculate_Delta_RelU()
	{
		if(this.relu_v==null||this.relu_v.length!=this.SelfLen)
			this.relu_v=new boolean[this.SelfLen];
		for(int i=0;i<this.SelfLen;i++)
			this.relu_v[i]=this.Y[i]>0;
		backpropagateFromNext();
		for(int i=0;i<this.SelfLen;i++)
		{
			float scale=this.dropoutApplied?this.dropoutMask[i]:1f;
			this.Delta[i]=this.relu_v[i]?this.V[i]*scale:0;
		}
	}

	/** Plain gradient step: W[i][j] += alpha * Delta[j] * X.Y[i]. */
	public void Delta_Weight()
	{
		for(int i=0;i<W.length;i++)
			for(int j=0;j<this.SelfLen;j++)
				this.W[i][j]+=Weight.alpha*this.Delta[j]*this.X.Y[i];
	}

	/**
	 * Momentum step: DW = alpha * Delta[j] * X.Y[i] + beta * DW, then W += DW.
	 * Does nothing unless the layer was created with {@code usedw == true}.
	 */
	public void Delta_Dw_Weight()
	{
		if(!this.usedw)return;
		for(int i=0;i<W.length;i++)
			for(int j=0;j<this.SelfLen;j++)
			{
				this.DW[i][j]=(Weight.alpha*this.Delta[j]*this.X.Y[i])+Weight.beta*this.DW[i][j];
				this.W[i][j]+=this.DW[i][j];
			}
	}

	/**
	 * Writes the weights of every layer after {@code head} to {@code path} as big-endian
	 * floats, layer by layer, row by row. Any previous content of the file is discarded.
	 *
	 * @throws Exception if {@code path} is a directory, or on I/O failure
	 */
	public static void Save(Weight head,String path) throws Exception
	{
		File file=new File(path);
		if(file.exists()&&file.isDirectory())throw new Exception("该位置不能保存文件");
		try(RandomAccessFile out=new RandomAccessFile(file, "rw"))
		{
			// "rw" does not truncate: without this, a smaller network would leave stale bytes behind.
			out.setLength(0);
			for(Weight layer=head.next;layer!=null;layer=layer.next)
				for(int i=0;i<layer.X.SelfLen;i++)
					for(int j=0;j<layer.SelfLen;j++)
						out.writeFloat(layer.W[i][j]);
		}
	}

	/**
	 * Reads weights written by {@link #Save(Weight, String)} into the layers after {@code head}.
	 * The network must have the same shape as the one that was saved. The file is opened
	 * read-only.
	 *
	 * @throws Exception if the file is missing or is a directory, or on I/O failure
	 *                   (including {@link java.io.EOFException} when the file is too short)
	 */
	public static void Load(Weight head,String path) throws Exception
	{
		File file=new File(path);
		if(!file.exists()||file.isDirectory())throw new Exception("文件打开失败!");
		try(RandomAccessFile in=new RandomAccessFile(file, "r"))
		{
			for(Weight layer=head.next;layer!=null;layer=layer.next)
				for(int i=0;i<layer.X.SelfLen;i++)
					for(int j=0;j<layer.SelfLen;j++)
						layer.W[i][j]=in.readFloat();
		}
	}

	/** Clears the momentum buffer of every momentum-enabled layer after {@code head}. */
	public static void ReStartDW(Weight head)
	{
		for(Weight layer=head.next;layer!=null;layer=layer.next)
		{
			if(!layer.usedw)
				continue;
			for(int i=0;i<layer.X.SelfLen;i++)
				for(int j=0;j<layer.SelfLen;j++)
					layer.DW[i][j]=0;
		}
	}

	/** Runs one layer's forward computation with its configured activation. */
	private static void forward(Weight layer)
	{
		if(layer.use_Softmax)
			layer.calculate_Y_Softmax();
		else if(layer.use_RelU)
			layer.calculate_Y_RelU();
		else
			layer.calculate_Y_Sigmoid();
	}

	/**
	 * Performs one training step on the sample currently held in {@code head.Y}: a forward pass
	 * (with dropout ratio 0.2 on dropout-enabled layers), back-propagation, and a weight update.
	 *
	 * @param head input layer of the network
	 * @param D    target output values
	 * @throws IllegalArgumentException if the network has no layer after {@code head}
	 * @throws Exception if {@code D} does not match the output layer's size
	 */
	public static void Train(Weight head,float D[]) throws Exception
	{
		if(head.next==null)
			throw new IllegalArgumentException("network has no trainable layers");
		Weight output=head;
		for(Weight layer=head.next;layer!=null;layer=layer.next)
		{
			forward(layer);
			if(layer.use_Dropout)
				layer.Dropout(0.2f);
			output=layer;
		}
		if(output.use_cross_entropy)
			output.calculate_Delta_cross_entropy(D);
		else
			output.calculate_Delta_Sigmoid(D);
		// Walk back toward the input; the input layer has no weights and ends the walk.
		for(Weight layer=output.X;layer.W!=null;layer=layer.X)
		{
			if(layer.use_RelU)
				layer.calculate_Delta_RelU();
			else
				layer.calculate_Delta_Sigmoid();
		}
		for(Weight layer=head.next;layer!=null;layer=layer.next)
		{
			if(layer.usedw)
				layer.Delta_Dw_Weight();
			else
				layer.Delta_Weight();
		}
	}

	/**
	 * Runs a forward pass without dropout and returns the output layer's Y array. The array is
	 * the layer's own buffer and is overwritten by the next pass. It is {@code head.Y} itself
	 * when no layers follow {@code head}.
	 */
	public static float[] Test(Weight head)
	{
		Weight output=head;
		for(Weight layer=head.next;layer!=null;layer=layer.next)
		{
			forward(layer);
			output=layer;
		}
		return output.Y;
	}
}


  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值