Java神经网络实现
Weight.java
package neural;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.RandomAccessFile;
import java.util.Random;
/**
 * A single layer of a fully connected feed-forward neural network.
 *
 * <p>Layers form a linked list. The head is an input layer created with {@link #Weight(float[])}
 * or {@link #Weight(int, boolean)}; it has no weight matrix. Each following layer, created with
 * {@link #add(int, boolean)} or {@link #Weight(Weight, int, boolean)}, owns a weight matrix
 * {@code W[X.SelfLen][SelfLen]} connecting its predecessor {@link #X} to itself, and is reachable
 * from the predecessor through {@link #next}.
 *
 * <p>The activation of each layer is selected by flags: Softmax when {@link #use_Softmax} is set,
 * otherwise ReLU when {@link #use_RelU} is set, otherwise an approximate Sigmoid. Optional
 * features are momentum ({@link #usedw}), inverted dropout during training
 * ({@link #use_Dropout}) and cross-entropy loss on the output layer ({@link #use_cross_entropy}).
 */
public class Weight {
    /** Shared random source for dropout masks, seeded from the wall clock. */
    static Random random = new Random(System.currentTimeMillis());

    /**
     * Learning rate ({@code alpha}) and momentum coefficient ({@code beta}), both in (0, 1).
     * A larger learning rate trains faster; a smaller one trains more precisely.
     */
    public static final float alpha = 0.01f, beta = 0.9f;

    /**
     * Approximate logistic sigmoid {@code 1 / (1 + e^-x)}. {@code e^|x|} is approximated as
     * {@code (1 + |x|/256)^256} using eight squarings, with no library calls, so the same code
     * can be ported directly to an OpenCL GPU kernel.
     *
     * @param x input value
     * @return approximate sigmoid of {@code x}, within [0, 1] for non-NaN input
     */
    public static float Extremely_fast_Sigmoid(float x) {
        float t = x < 0 ? -x : x;
        t = 1.0f + t / 256f;
        for (int i = 0; i < 8; i++) {
            t *= t; // after 8 squarings: t = (1 + |x|/256)^256 ~ e^|x|
        }
        float y = x < 0 ? t : 1.0f / t; // y ~ e^-x
        return 1.0f / (1.0f + y);
    }

    /**
     * Approximates {@code e^x} as {@code (1 + |x|/256)^256} (reciprocal for negative x).
     * The result overflows to {@code +Infinity} once x exceeds roughly 106.
     *
     * @param x exponent
     * @return approximation of {@code e^x}
     */
    public static float Exp(float x) {
        float t = x < 0 ? -x : x;
        t = 1.0f + t / 256f;
        for (int i = 0; i < 8; i++) {
            t *= t;
        }
        return x < 0 ? 1.0f / t : t;
    }

    /** Exact logistic sigmoid {@code 1 / (1 + e^-x)} computed with {@link Math#exp} (CPU version). */
    public static float Sigmoid(float x) {
        return 1.0f / (1.0f + (float) Math.exp(-x));
    }

    /**
     * Sigmoid derivative expressed through the sigmoid's output: for {@code y = sigmoid(x)},
     * {@code dy/dx = y * (1 - y)}.
     */
    public static float dSigmoid(float y) {
        return y * (1 - y);
    }

    /**
     * Optional source of initial weight values used by {@link #ReStartWeight()}. When it is
     * {@code null}, newly added layers keep all-zero weights.
     */
    public static Callback callback;

    /**
     * Per-unit buffers. {@code V} holds the weighted input sums of the last forward pass (for a
     * Softmax layer, the shifted exponentials instead). {@code Y} is the layer output.
     * {@code Delta} holds the error terms of the last backward pass.
     */
    public float[] V, Y, Delta;

    /** {@code W[i][j]} connects input unit i to unit j; {@code DW} is the last update (momentum only). */
    public float[][] W, DW;

    /** Previous layer ({@code null} for an input layer) and following layer. */
    public Weight X, next;

    /** Number of units in this layer. */
    int SelfLen;

    /** Feature flags; see the class documentation. */
    public boolean usedw, use_cross_entropy, use_Softmax, use_RelU, use_Dropout;

    /** Scratch mask filled by {@link #calculate_Delta_RelU()}: whether each unit's weighted sum was positive. */
    public boolean[] relu_v;

    /** Multipliers from the last {@link #Dropout(float)} call: 0 for dropped units, 1/(1-ratio) for kept ones. */
    private float[] dropMask;

    /** True when {@link #Dropout(float)} ran after this layer's most recent forward pass. */
    private boolean dropoutActive;

    /** Enables Softmax activation for this layer together with cross-entropy loss (meant for the output layer). */
    public void useSoftmax() {
        System.err.println("Softmax被启用\n交叉熵被启用");
        use_cross_entropy = true;
        use_Softmax = true;
    }

    /**
     * Enables inverted dropout for this layer during {@link #Train}.
     *
     * @return this layer, for chaining
     */
    public Weight useDropout() {
        System.err.println("随机节点丢弃被启用");
        use_Dropout = true;
        return this;
    }

    /**
     * Enables ReLU activation for this layer.
     *
     * @return this layer, for chaining
     */
    public Weight useRelU() {
        System.err.println("RelU被启用");
        relu_v = new boolean[this.SelfLen];
        use_RelU = true;
        return this;
    }

    /**
     * Picks {@code count} distinct random integers from {@code [0, max)} using a partial
     * Fisher-Yates shuffle. Earlier code folded {@code nextInt()} with abs/modulo, which produced a
     * negative value for {@code Integer.MIN_VALUE} and looped forever when {@code count > max}.
     *
     * @param max   exclusive upper bound of the values
     * @param count number of distinct values to draw
     * @return array of {@code count} distinct values in random order
     * @throws IllegalArgumentException if {@code count} is negative or greater than {@code max}
     */
    public static int[] randperm(int max, int count) {
        if (count == 0) {
            return new int[0];
        }
        if (count < 0 || count > max) {
            throw new IllegalArgumentException(
                    "count must be in [0, max], got count=" + count + ", max=" + max);
        }
        int[] pool = new int[max];
        for (int i = 0; i < max; i++) {
            pool[i] = i;
        }
        int[] picked = new int[count];
        for (int i = 0; i < count; i++) {
            int j = i + random.nextInt(max - i); // uniform over the not-yet-picked tail
            int swap = pool[i];
            pool[i] = pool[j];
            pool[j] = swap;
            picked[i] = pool[i];
        }
        return picked;
    }

    /**
     * Applies inverted dropout to the current output {@code Y}. A rounded fraction
     * {@code 1 - ratio} of units is kept and scaled by {@code 1 / (1 - ratio)}; the rest are
     * zeroed. The mask is kept separately so {@code V} still holds the weighted sums needed
     * by backpropagation.
     *
     * @param ratio fraction of units to drop, in [0, 1]
     * @throws IllegalArgumentException if {@code ratio} is outside [0, 1] or NaN
     */
    public void Dropout(float ratio) {
        if (!(ratio >= 0f && ratio <= 1f)) {
            throw new IllegalArgumentException("dropout ratio must be in [0, 1], got " + ratio);
        }
        if (dropMask == null || dropMask.length != SelfLen) {
            dropMask = new float[SelfLen];
        }
        for (int i = 0; i < SelfLen; i++) {
            dropMask[i] = 0f;
        }
        int keep = Math.round(SelfLen * (1 - ratio)); // round half up, as before
        float scale = 1 / (1 - ratio);
        for (int idx : randperm(SelfLen, keep)) {
            dropMask[idx] = scale;
        }
        for (int i = 0; i < SelfLen; i++) {
            Y[i] *= dropMask[i];
        }
        dropoutActive = true;
    }

    /**
     * Creates a hidden/output layer fed by {@code X}, with zero-initialized weights.
     *
     * @param X       previous layer
     * @param SelfLen number of units in the new layer
     * @param usedw   whether to allocate a momentum buffer {@code DW}
     */
    public Weight(Weight X, int SelfLen, boolean usedw) {
        this.X = X;
        this.SelfLen = SelfLen;
        Y = new float[SelfLen];
        V = new float[SelfLen];
        Delta = new float[SelfLen];
        W = new float[this.X.SelfLen][SelfLen];
        this.usedw = usedw;
        if (usedw) {
            DW = new float[this.X.SelfLen][SelfLen];
            System.err.println("Momentum is enabled, may consume excessive resources. Size=" + X.SelfLen + 'x' + this.SelfLen);
        }
    }

    /** Creates an input layer that uses (does not copy) the given array as its output. */
    public Weight(float[] Y) {
        this.Y = Y;
        this.SelfLen = Y.length;
        this.X = null;
    }

    /**
     * Creates an input layer of the given size.
     *
     * @param SelfLen number of input units
     * @param __new__ whether to allocate {@code Y} now (otherwise supply it via quote_Y/copy_Y)
     */
    public Weight(int SelfLen, boolean __new__) {
        if (__new__) {
            this.Y = new float[SelfLen];
        }
        this.SelfLen = SelfLen;
        this.X = null;
    }

    /** Makes {@code Y} refer to the given array (no copy). */
    public void quote_Y(float[] Y) {
        this.Y = Y;
    }

    /** Copies the first {@code SelfLen} values of {@code Y} into this layer's output, allocating it if needed. */
    public void copy_Y(float[] Y) {
        if (this.Y == null) {
            this.Y = new float[this.SelfLen];
        }
        System.arraycopy(Y, 0, this.Y, 0, this.SelfLen);
    }

    /**
     * Re-initializes every weight from {@link #callback}. Does nothing when no callback is
     * installed or when this layer has no weight matrix (an input layer).
     */
    public void ReStartWeight() {
        if (callback == null || W == null) {
            return;
        }
        for (int i = 0; i < W.length; i++) {
            for (int j = 0; j < W[i].length; j++) {
                W[i][j] = Weight.callback.getrandnum();
            }
        }
    }

    /**
     * Appends a new layer after this one and initializes its weights.
     *
     * @return the new layer, for chaining
     */
    public Weight add(int SelfLen, boolean usedw) {
        Weight tmp = new Weight(this, SelfLen, usedw);
        this.next = tmp;
        tmp.ReStartWeight();
        return tmp;
    }

    /** Links an existing layer after this one (its {@code X} is not updated; see {@link #setX}). */
    public void add(Weight weight) {
        this.next = weight;
    }

    /** Sets the previous layer feeding this one. */
    public void setX(Weight weight) {
        this.X = weight;
    }

    /**
     * Computes {@code V[j] = sum_i W[i][j] * X.Y[i]} and marks any earlier dropout mask as
     * stale, since a new forward pass has started.
     */
    private void computeWeightedSums() {
        float[] in = X.Y;
        for (int j = 0; j < SelfLen; j++) {
            V[j] = 0;
        }
        for (int i = 0; i < W.length; i++) {
            float[] row = W[i];
            float xi = in[i];
            for (int j = 0; j < SelfLen; j++) {
                V[j] += row[j] * xi;
            }
        }
        dropoutActive = false;
    }

    /** Forward pass with the approximate sigmoid activation. */
    public void calculate_Y_Sigmoid() {
        computeWeightedSums();
        for (int i = 0; i < SelfLen; i++) {
            Y[i] = Extremely_fast_Sigmoid(V[i]);
        }
    }

    /** Forward pass with ReLU activation. */
    public void calculate_Y_RelU() {
        computeWeightedSums();
        for (int i = 0; i < SelfLen; i++) {
            Y[i] = V[i] > 0 ? V[i] : 0;
        }
    }

    /**
     * Forward pass with Softmax activation. The largest weighted sum is subtracted before
     * exponentiation so {@link #Exp} never overflows (large logits used to produce Inf/Inf = NaN).
     * This does not change the softmax mathematically. Afterwards {@code V} holds the shifted
     * exponentials.
     */
    public void calculate_Y_Softmax() {
        computeWeightedSums();
        float max = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < SelfLen; i++) {
            if (V[i] > max) {
                max = V[i];
            }
        }
        float sum = 0;
        for (int i = 0; i < SelfLen; i++) {
            float e = Exp(V[i] - max); // argument <= 0, so e is in (0, 1]
            V[i] = e;
            sum += e;
        }
        for (int i = 0; i < SelfLen; i++) {
            Y[i] = V[i] / sum;
        }
    }

    /** Gradient multiplier of unit i: its dropout mask value if dropout was applied, else 1. */
    private float dropoutScale(int i) {
        return dropoutActive ? dropMask[i] : 1f;
    }

    /** Output of unit i before dropout scaling (equal to {@code Y[i]} when no dropout was applied). */
    private float preDropoutY(int i) {
        if (!dropoutActive) {
            return Y[i];
        }
        float m = dropMask[i];
        return m != 0f ? Y[i] / m : 0f;
    }

    /** Throws if {@code D} does not have one entry per unit. */
    private void requireSameLength(float[] D) {
        if (SelfLen != D.length) {
            throw new IllegalArgumentException(
                    "inconsistent length: expected " + SelfLen + ", got " + D.length);
        }
    }

    /**
     * Output-layer error terms for sigmoid units with squared-error loss:
     * {@code Delta = sigmoid'(.) * (D - Y)}. The dropout mask is applied when present.
     *
     * @param D target values, one per unit
     * @throws Exception (an IllegalArgumentException) if {@code D.length != SelfLen}
     */
    public void calculate_Delta_Sigmoid(float D[]) throws Exception {
        requireSameLength(D);
        for (int i = 0; i < SelfLen; i++) {
            Delta[i] = dropoutScale(i) * dSigmoid(preDropoutY(i)) * (D[i] - Y[i]);
        }
    }

    /**
     * Output-layer error terms for Softmax with cross-entropy loss: {@code Delta = D - Y}.
     *
     * @param D target values, one per unit
     * @throws Exception (an IllegalArgumentException) if {@code D.length != SelfLen}
     */
    public void calculate_Delta_cross_entropy(float D[]) throws Exception {
        requireSameLength(D);
        for (int i = 0; i < SelfLen; i++) {
            Delta[i] = D[i] - Y[i];
        }
    }

    /** Returns {@code sum_j next.W[i][j] * next.Delta[j]}, the error backpropagated to unit i. */
    private float backpropagatedError(int i) {
        float[] row = next.W[i];
        float[] nextDelta = next.Delta;
        float sum = 0;
        for (int j = 0; j < next.SelfLen; j++) {
            sum += row[j] * nextDelta[j];
        }
        return sum;
    }

    /**
     * Hidden-layer error terms for sigmoid units, from the following layer's weights and
     * deltas. Uses the pre-dropout output for the derivative and applies the dropout mask.
     */
    public void calculate_Delta_Sigmoid() {
        for (int i = 0; i < SelfLen; i++) {
            Delta[i] = dropoutScale(i) * dSigmoid(preDropoutY(i)) * backpropagatedError(i);
        }
    }

    /**
     * Hidden-layer error terms for ReLU units. The derivative mask ({@link #relu_v}) comes
     * from the weighted sums in {@code V}, which dropout no longer overwrites. The dropout mask
     * is applied when present.
     */
    public void calculate_Delta_RelU() {
        if (relu_v == null || relu_v.length != SelfLen) {
            relu_v = new boolean[SelfLen];
        }
        for (int i = 0; i < SelfLen; i++) {
            relu_v[i] = V[i] > 0;
            Delta[i] = relu_v[i] ? dropoutScale(i) * backpropagatedError(i) : 0;
        }
    }

    /** Plain gradient step: {@code W[i][j] += alpha * Delta[j] * X.Y[i]}. */
    public void Delta_Weight() {
        float[] in = X.Y;
        for (int i = 0; i < W.length; i++) {
            float[] row = W[i];
            for (int j = 0; j < SelfLen; j++) {
                row[j] += alpha * Delta[j] * in[i];
            }
        }
    }

    /**
     * Momentum step: {@code DW = alpha * Delta[j] * X.Y[i] + beta * DW; W += DW}.
     * Does nothing when momentum ({@link #usedw}) is disabled.
     */
    public void Delta_Dw_Weight() {
        if (!usedw) {
            return;
        }
        float[] in = X.Y;
        for (int i = 0; i < W.length; i++) {
            float[] row = W[i];
            float[] dRow = DW[i];
            for (int j = 0; j < SelfLen; j++) {
                dRow[j] = (alpha * Delta[j] * in[i]) + beta * dRow[j];
                row[j] += dRow[j];
            }
        }
    }

    /**
     * Writes the weights of every layer after {@code head} to {@code path}, layer by layer and
     * row by row, as big-endian 32-bit floats. An existing file is truncated first, so no stale
     * bytes from a larger earlier model remain. The stream is closed even on failure.
     *
     * @throws Exception if {@code path} is a directory or an I/O error occurs
     */
    public static void Save(Weight head, String path) throws Exception {
        File file = new File(path);
        if (file.exists() && file.isDirectory()) {
            throw new Exception("该位置不能保存文件");
        }
        try (DataOutputStream out =
                new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            for (Weight layer = head.next; layer != null; layer = layer.next) {
                for (int i = 0; i < layer.X.SelfLen; i++) {
                    for (int j = 0; j < layer.SelfLen; j++) {
                        out.writeFloat(layer.W[i][j]);
                    }
                }
            }
        }
    }

    /**
     * Reads weights in the format written by {@link #Save}. The network must already have the
     * same layer sizes. The file is opened read-only and always closed.
     *
     * @throws Exception if the file is missing or a directory, is too short
     *     ({@link java.io.EOFException}), or another I/O error occurs
     */
    public static void Load(Weight head, String path) throws Exception {
        File file = new File(path);
        if (!file.exists() || file.isDirectory()) {
            throw new Exception("文件打开失败!");
        }
        try (DataInputStream in =
                new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            for (Weight layer = head.next; layer != null; layer = layer.next) {
                for (int i = 0; i < layer.X.SelfLen; i++) {
                    for (int j = 0; j < layer.SelfLen; j++) {
                        layer.W[i][j] = in.readFloat();
                    }
                }
            }
        }
    }

    /** Clears the momentum buffers of every momentum-enabled layer after {@code head}. */
    public static void ReStartDW(Weight head) {
        for (Weight layer = head.next; layer != null; layer = layer.next) {
            if (!layer.usedw) {
                continue;
            }
            for (int i = 0; i < layer.X.SelfLen; i++) {
                for (int j = 0; j < layer.SelfLen; j++) {
                    layer.DW[i][j] = 0;
                }
            }
        }
    }

    /**
     * Runs the forward pass over all layers after {@code head}. When {@code training} is true,
     * dropout is applied to layers that enabled it.
     *
     * @return the last layer, or {@code head} itself if it has no following layers
     */
    private static Weight forward(Weight head, boolean training, float dropoutRatio) {
        Weight last = head;
        for (Weight layer = head.next; layer != null; layer = layer.next) {
            if (layer.use_Softmax) {
                layer.calculate_Y_Softmax();
            } else if (layer.use_RelU) {
                layer.calculate_Y_RelU();
            } else {
                layer.calculate_Y_Sigmoid();
            }
            if (training && layer.use_Dropout) {
                layer.Dropout(dropoutRatio);
            }
            last = layer;
        }
        return last;
    }

    /**
     * Performs one training step on the sample already loaded into {@code head.Y}, with a
     * dropout ratio of 0.2 for layers that enabled dropout.
     *
     * @param head input layer
     * @param D    target output values
     * @throws Exception if {@code D} has the wrong length or the network has no layers
     */
    public static void Train(Weight head, float D[]) throws Exception {
        Train(head, D, 0.2f);
    }

    /**
     * Performs one training step: forward pass (with dropout), backpropagation of the error for
     * {@code D}, then a weight update for every layer.
     *
     * @param head         input layer
     * @param D            target output values
     * @param dropoutRatio fraction of units dropped in layers that enabled dropout, in [0, 1]
     * @throws Exception if {@code D} has the wrong length or the network has no layers
     */
    public static void Train(Weight head, float[] D, float dropoutRatio) throws Exception {
        Weight output = forward(head, true, dropoutRatio);
        if (output == head) {
            throw new IllegalStateException("network has no trainable layers");
        }
        if (output.use_cross_entropy) {
            output.calculate_Delta_cross_entropy(D);
        } else {
            output.calculate_Delta_Sigmoid(D);
        }
        // Walk back through the hidden layers; the input layer has no weights and ends the walk.
        for (Weight layer = output.X; layer != null && layer.W != null; layer = layer.X) {
            if (layer.use_RelU) {
                layer.calculate_Delta_RelU();
            } else {
                layer.calculate_Delta_Sigmoid();
            }
        }
        for (Weight layer = head.next; layer != null; layer = layer.next) {
            if (layer.usedw) {
                layer.Delta_Dw_Weight();
            } else {
                layer.Delta_Weight();
            }
        }
    }

    /**
     * Evaluates the network (no dropout) on the input loaded into {@code head.Y}.
     *
     * @return the output layer's {@code Y} array (not a copy), or {@code head.Y} if there are no layers
     */
    public static float[] Test(Weight head) {
        return forward(head, false, 0f).Y;
    }
}