2018华为软件精英挑战赛2

预测部分,我们使用的是GRU,也就是简化版的LSTM。其实根据我们的经验,这次赛事使用神经网络效果并不好,因为神经网络需要大量的数据喂食和训练,而这次比赛提供的训练数据量太少了。以下把我们的代码贴出来,大家一起讨论。代码很大程度上参考了github上的https://github.com/Root-lee/Put_Flavors_SA和https://github.com/lipiji/JRNN,在此非常感谢!

我们认为,作为神经网络训练来说,最主要的是分三部分:第一,数据的预处理。这里的数据预处理指的是,通过一些算法,把一些特定点作为异常点或者噪点抹去(例如一个特别高的点)以及对现有的数据的特征进行深入挖掘,提取出更多有效的特征,从而提高训练的有效性和准确性,以及进行一些后续处理。至少,对于这次比赛来说,这部分非常重要,对预测的准确性有极大的影响。

一、数据预处理的去异常点部分

对于神经网络训练来说,不同的数据通常有很不一样的特征。我们认为,要确定怎么去异常点,首先要通过人眼进行观察,从而有个大致的思路。我们因此把比赛官方给我们测试用的某个训练数据集,自己写了代码,画了图,进行观察。官方的训练数据集是每天的一些虚拟机申请数据,如下:


这里面:1、头尾的序号和具体时间是没用的,只有虚拟机名字和日期是有用数据。而虚拟机名字又有些是超出了我们需要预测的范围的没用数据,要剔除。另外,日期也不一定连续,需要用java的日期判断函数把不连续的日期补上。

//这份文件,把所有需要的数据都提炼出来
class FileUtil//官方提供的那个读取文件的类,我把它变成了静态内部类
{
    public static String[] read(final String filePath, final Integer spec)
    {
        File file = new File(filePath);
        // 当文件不存在或者不可读时
        if ((!isFileExists(file)) || (!file.canRead()))
        {
            System.out.println("file [" + filePath + "] is not exist or cannot read!!!");
            return null;
        }
        
        List<String> lines = new LinkedList<String>();
        BufferedReader br = null;
        FileReader fb = null;
        try
        {
            fb = new FileReader(file);
            br = new BufferedReader(fb);

            String str = null;
            int index = 0;
            while (((spec == null) || index++ < spec) && (str = br.readLine()) != null)
            {
                lines.add(str);
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            closeQuietly(br);
            closeQuietly(fb);
        }

        return lines.toArray(new String[lines.size()]);
    }

    public static int write(final String filePath, final String[] contents, final boolean append)
    {
        File file = new File(filePath);
        if (contents == null)
        {
            System.out.println("file [" + filePath + "] invalid!!!");
            return 0;
        }

        
        if (isFileExists(file) && (!file.canRead()))
        {
            return 0;
        }

        FileWriter fw = null;
        BufferedWriter bw = null;
        try
        {
            if (!isFileExists(file))
            {
                file.createNewFile();
            }

            fw = new FileWriter(file, append);
            bw = new BufferedWriter(fw);
            for (String content : contents)
            {
                if (content == null)
                {
                    continue;
                }
                bw.write(content);
                bw.newLine();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
            return 0;
        }
        finally
        {
            closeQuietly(bw);
            closeQuietly(fw);
        }

        return 1;
    }

    private static void closeQuietly(Closeable closeable)
    {
        try
        {
            if (closeable != null)
            {
                closeable.close();
            }
        }
        catch (IOException e)
        {
        }
    }

    private static boolean isFileExists(final File file)
    {
        if (file.exists() && file.isFile())
        {
            return true;
        }

        return false;
    }

}


public class datedataoutput {

	public static void main(String[] args) throws ParseException {
		// TODO 自动生成的方法存根
		String filepath="F:\\我的亿方云同步\\FangCloudV2\\个人文件\\华为比赛相关\\20180322datedatasumury\\originaldata.txt";
		String[] readresult=FileUtil.read(filepath, null);//日期是从前往后的
		ArrayList<ArrayList<String>> alldatalistdate=new ArrayList<ArrayList<String>>();
		ArrayList<ArrayList<Integer>> alldatalistnum=new ArrayList<ArrayList<Integer>>();
		//----------------先把第一天的放进去
		String tempday=readresult[0].split("\\s+")[2];
		String tempflavor=readresult[0].split("\\s+")[1];
		
		alldatalistdate.add(new ArrayList<String>());
		alldatalistnum.add(new ArrayList<Integer>());
		
		alldatalistdate.get(0).add(tempday);
		alldatalistdate.get(0).add(getWeek(tempday));
		
		for(int i=0;i<18;i++)
			alldatalistnum.get(0).add(0);
		
		int tempflavorindex=Integer.valueOf(tempflavor.substring(6));//当天虚拟机的标号
		if(tempflavorindex>=1&&tempflavorindex<=15){//排除异常flavor影响
			alldatalistnum.get(0).set(0, judgeCPU("flavor"+tempflavorindex));
			alldatalistnum.get(0).set(1, judgeMemory("flavor"+tempflavorindex));
			alldatalistnum.get(0).set(tempflavorindex+1, 1);
			alldatalistnum.get(0).set(17, 1);
		}
		//----------------先把第一天的放进去
		for(int i=1;i<readresult.length;i++){
			String thisday=readresult[i].split("\\s+")[2];
			String thisflavor=readresult[i].split("\\s+")[1];
			if(thisday.equals(alldatalistdate.get(alldatalistdate.size()-1).get(0))){//如果是同一天
				int tempflavori=Integer.valueOf(thisflavor.substring(6));//当天这台虚拟机的标号
				if(tempflavori>=1&&tempflavori<=15){
					alldatalistnum.get(alldatalistnum.size()-1).set(0, alldatalistnum.get(alldatalistnum.size()-1).get(0)+judgeCPU("flavor"+tempflavori));
					alldatalistnum.get(alldatalistnum.size()-1).set(1, alldatalistnum.get(alldatalistnum.size()-1).get(1)+judgeMemory("flavor"+tempflavori));
					alldatalistnum.get(alldatalistnum.size()-1).set(tempflavori+1, alldatalistnum.get(alldatalistnum.size()-1).get(tempflavori+1)+1);
					alldatalistnum.get(alldatalistnum.size()-1).set(17, alldatalistnum.get(alldatalistnum.size()-1).get(17)+1);
				}
			}
			else{//如果是不同的一天,那就要重新建立了
				while(dayDifference(thisday, alldatalistdate.get(alldatalistdate.size()-1).get(0))>1){//相互之间日期相差1以上,要把中间空缺的日期补上
					String lastdate=alldatalistdate.get(alldatalistdate.size()-1).get(0);
					alldatalistdate.add(new ArrayList<String>());
					alldatalistdate.get(alldatalistdate.size()-1).add(nextDay(lastdate));
					alldatalistdate.get(alldatalistdate.size()-1).add(getWeek(nextDay(lastdate)));
					alldatalistnum.add(new ArrayList<Integer>());
					for(int z=0;z<18;z++)
						alldatalistnum.get(alldatalistnum.size()-1).add(0);
				}
				alldatalistdate.add(new ArrayList<String>());
				alldatalistnum.add(new ArrayList<Integer>());
				
				alldatalistdate.get(alldatalistdate.size()-1).add(thisday);
				alldatalistdate.get(alldatalistdate.size()-1).add(getWeek(thisday));
				
				for(int j=0;j<18;j++)
					alldatalistnum.get(alldatalistnum.size()-1).add(0);
				
				int tempflavori=Integer.valueOf(thisflavor.substring(6));
				
				if(tempflavori>=1&&tempflavori<=15){//排除异常flavor影响
					alldatalistnum.get(alldatalistnum.size()-1).set(0, judgeCPU("flavor"+tempflavori));
					alldatalistnum.get(alldatalistnum.size()-1).set(1, judgeMemory("flavor"+tempflavori));
					alldatalistnum.get(alldatalistnum.size()-1).set(tempflavori+1, 1);
					alldatalistnum.get(alldatalistnum.size()-1).set(17, 1);
				}
			}
		}
		//这个时候,alldatalistdate和alldatalistnum已经装好了
		for(int k=0;k<alldatalistdate.size();k++){
			System.out.print(alldatalistdate.get(k).get(0)+"  "+alldatalistdate.get(k).get(1)+"  ");
			for(int b=0;b<18;b++)
				System.out.print(alldatalistnum.get(k).get(b)+"  ");
			System.out.println();
		}
//		for(int k=0;k<readresult.length;k++)
//			System.out.println(readresult[k]);
	}
	public static int judgeCPU(String s) {//返回虚拟机CPU的消耗
		if(s.equals("flavor1")||s.equals("flavor2")||s.equals("flavor3"))
			return 1;
		else if(s.equals("flavor4")||s.equals("flavor5")||s.equals("flavor6"))
			return 2;
		else if(s.equals("flavor7")||s.equals("flavor8")||s.equals("flavor9"))
			return 4;
		else if(s.equals("flavor10")||s.equals("flavor11")||s.equals("flavor12"))
			return 8;
		else if(s.equals("flavor13")||s.equals("flavor14")||s.equals("flavor15"))
			return 16;
		else 
			return 0;
	}
	public static int judgeMemory(String s) {//返回虚拟机Memory的消耗,单位GB
		if(s.equals("flavor1"))
			return 1;
		else if(s.equals("flavor2")||s.equals("flavor4"))
			return 2;
		else if(s.equals("flavor3")||s.equals("flavor5"))
			return 4;
		else if(s.equals("flavor6")||s.equals("flavor7")||s.equals("flavor8")||s.equals("flavor10"))
			return 8;
		else if(s.equals("flavor9")||s.equals("flavor11")||s.equals("flavor13"))
			return 16;
		else if(s.equals("flavor12")||s.equals("flavor14"))
			return 32;
		else if(s.equals("flavor15"))
			return 64;
		else
			return 0;
	}
	//返回星期几,中文
	public static String getWeek(String sdate) {  
	    // 再转换为时间  
	     Date date = strToDate(sdate);  
	     Calendar c = Calendar.getInstance();  
	     c.setTime(date);  
	     // int hour=c.get(Calendar.DAY_OF_WEEK);  
	     // hour中存的就是星期几了,其范围 1~7  
	     // 1=星期日 7=星期六,其他类推  
	     return new SimpleDateFormat("EEEE").format(c.getTime());  
	 }  
	public static Date strToDate(String strDate) {  
	     SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd");  
	     ParsePosition pos = new ParsePosition(0);  
	     Date strtodate = formatter.parse(strDate, pos);  
	     return strtodate;  
	 }
	public static long dayDifference(String d1,String d2) throws ParseException{//天数大的是1,天数小的是2
		Date a1 = new SimpleDateFormat("yyyy-MM-dd").parse(d1);
		Date b1 = new SimpleDateFormat("yyyy-MM-dd").parse(d2);
		//获取相减后天数
		long day = (a1.getTime()-b1.getTime())/(24*60*60*1000);
		return day;
	}
	public static String nextDay(String d){//返回这一天的下一天的string表示
		SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
        Date date = sdf.parse(d, new ParsePosition(0));
        Calendar calendar = Calendar.getInstance();
        calendar.setTime(date);
        // add方法中的第二个参数n中,正数表示该日期后n天,负数表示该日期的前n天
        calendar.add(Calendar.DATE, 1);
        Date date1 = calendar.getTime();
        String out = sdf.format(date1);
        return out;
	}
}

最后我们把文件在excel上进行处理,并且用origin画柱形图(不能画散点图,没法看清趋势),得到了以下的数据:

可以发现,这些数据有以下几个特点:

1、对于单个Flavor来说,非常稀疏;

2、时间间隔很大;

3、有一些明显的天数里,虚拟机申请量极高,高到不同寻常。

因此,必须找到一些数据处理办法,把这些异常点去掉。我们尝试过的方法有:

1、如果某天申请量超过了某个阈值,那就把那天的申请量变为那天的前一天和后一天的平均值。如果那天刚好在第一天(最后一天),则变为后一天(前一天)的两倍减去后第二天(前第二天)。这样做我们认为可以去去掉异常点的同时,让改天反映出正常的变化趋势。我们最终使用的是这个方法,而且在所有尝试过的方法中效果是最好的(阈值设为22)

	public static void dealwithlargedatabeyondconstantreplacefrontbackavgnew(double[][] input,double expect) {
		int col=input[0].length;
		for(int i=0;i<row;i++)
			for(int j=0;j<col;j++)
				if(input[i][j]>expect)
					if(i==0)
						input[i][j]=2*input[i+1][j]-input[i+2][j];
					else if(i==row-1)
						input[i][j]=2*input[i-1][j]-input[i-2][j];
					else
						input[i][j]=(input[i-1][j]+input[i+1][j])/2;

2、正态分布法

在正态分布的假设下,可以算出全部序列值的均值和方差,如果某个值分布超过了标准差的三倍,则认为是异常点。

3、截头去尾法

把一系列时间序列的数据排序,把头或尾的某个百分比的数据都认为是异常点。我询问过一位在罗切斯特大学学计算机的同学,他们经常使用这个方法作为数据预处理。

二、数据特征挖掘部分

作为原始的数据,数据特征的挖掘很重要,因为原始数据的特征太少了,只有一个时间和个数信息。这个时候的疑问是,申请是否跟星期几有关?各个虚拟机申请直接是否有关系?CPU和MEM总数有没有影响?尽管我们都尝试过把他们作为数据的特征之一加进去,但是效果仍然不尽理想,最后我们提交的代码,并没有用上除了时间和个数之外的额外信息。例如,如果看当天的CPU和MEM总数,可以发现:


似乎隐隐有一个周期性的变化,大概以一个星期为周期。

1、加入星期几

我们尝试过把星期几的信息变为一个7维向量加入到每天的特征中。

2、是否周末

我们尝试过把是否周末作为特征加进去。

3、当天的CPU或MEM总量

我们尝试过把当天CPU和MEM总量加进去。

上面这些方法,效果都比不上什么都不加好。

三、数据进一步处理部分

为什么要归一化呢?在机器学习中领域中的数据分析之前,通常需要将数据标准化,利用标准化后得数据进行数据分析。不同评价指标往往具有不同的量纲和量纲单位,这样的情况会影响到数据分析的结果,为了消除指标之间的量纲影响,需要进行数据标准化处理,以解决数据指标之间的可比性。原始数据经过数据标准化处理后,各指标处于同一数量级,适合进行综合对比评价。我们用的是最常用的z-score归一化方法。当然还有一个最大最小标准化,效果不如z-score好。

四、预测部分

预测部分就很考验调参了。稍微一点变化都会导致分数有巨大的变动。为了保证结果可重复性,有随机数的地方我们都采用了随机数种子。

1、GRU的权重矩阵。在这里,我们分别用两种不同的方法初始化。一种是均匀随机数,一种是高斯随机数。其实我们还试过高斯随机截头去尾法,也就是对每个生成的高斯随机数,如果原理均值太远,则重新生成。最后使用的是0-1之间的均匀随机数。如果使用高斯随机数,只要稍微对方差做出一点变化,对整体分数变化都很大,各个参数都要重新调整。

    private int inSize;
    private int outSize;
    private int deSize;

    private DoubleMatrix Wxr;
    private DoubleMatrix Whr;
    private DoubleMatrix br;

    private DoubleMatrix Wxz;
    private DoubleMatrix Whz;
    private DoubleMatrix bz;

    private DoubleMatrix Wxh;
    private DoubleMatrix Whh;
    private DoubleMatrix bh;

    private DoubleMatrix Why;
    private DoubleMatrix by;

    public GRU(int inSize, int outSize, MatIniter initer) {
        this.inSize = inSize;
        this.outSize = outSize;

        if (initer.getType() == Type.Uniform) {
            this.Wxr = initer.uniform(inSize, outSize);
            this.Whr = initer.uniform(outSize, outSize);
            this.br = new DoubleMatrix(1, outSize);

            this.Wxz = initer.uniform(inSize, outSize);
            this.Whz = initer.uniform(outSize, outSize);
            this.bz = new DoubleMatrix(1, outSize);

            this.Wxh = initer.uniform(inSize, outSize);
            this.Whh = initer.uniform(outSize, outSize);
            this.bh = new DoubleMatrix(1, outSize);

            this.Why = initer.uniform(outSize, inSize);
            this.by = new DoubleMatrix(1, inSize);
        } else if (initer.getType() == Type.Gaussian) {
            this.Wxr = initer.gaussian(inSize, outSize);
            this.Whr = initer.gaussian(outSize, outSize);
            this.br = new DoubleMatrix(1, outSize);

            this.Wxz = initer.gaussian(inSize, outSize);
            this.Whz = initer.gaussian(outSize, outSize);
            this.bz = new DoubleMatrix(1, outSize);

            this.Wxh = initer.gaussian(inSize, outSize);
            this.Whh = initer.gaussian(outSize, outSize);
            this.bh = new DoubleMatrix(1, outSize);

            this.Why = initer.gaussian(outSize, inSize);
            this.by = new DoubleMatrix(1, inSize);
        }
    }

2、训练传递部分。训练部分,就是考验调参的时刻了。学习速率,每次训练的时间窗口,训练次数,都对分数有巨大影响。而且因为计算时间和计算能力的限制,只能采用随机梯度下降算法来进行训练,这就导致有可能在训练损失函数上有跳变,无法完全收敛。这里,其实我们还尝试过一些自适应学习率下降的方法,例如动量法等,但是其实效果都不明显,最后还是就使用了固定的学习速率。还有一些更高深的训练方法,也没有尝试过。

    public void train(double[][] data, double lr) {
//        java.util.Random random = new java.util.Random();
        java.util.Random random = new java.util.Random(5);
        int slideT = GRUPredict.slideT;
        for (int i = 0; i < 2500; i++) {
/**
 * 学习率处理-自动变小
 * ***/
//            if (i-i/2000*2000==0)
//                lr = lr / 3;

            double error = 0;
            double num = 0;
            double start = System.currentTimeMillis();
            acts = new HashMap<>();
            int index = random.nextInt(data.length - slideT);

            for (int t = 0; t < slideT; t++) {
                double[] dayData = data[index+t];
                DoubleMatrix xt = new DoubleMatrix(1,dayData.length,dayData);
                acts.put("x" + t, xt);

                gru.active(t, acts);

                DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
                acts.put("py" + t, predcitYt);
                DoubleMatrix trueYt = new DoubleMatrix(1,data[index+1+t].length,data[index+1+t]);
                acts.put("y" + t, trueYt);

                    //System.out.print(indexChar.get(predcitYt.argmax()));
                error += LossFunction.getLoss(predcitYt, trueYt);

            }
//            for (int t = 0; t < data.length-1; t++) {打印每次loss function观察收敛情况
//                double[] dayData = data[t];
//                DoubleMatrix xt = new DoubleMatrix(1,dayData.length,dayData);
//                acts.put("x" + t, xt);
//
//                gru.active(t, acts);
//
//                DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
//                acts.put("py" + t, predcitYt);
//                DoubleMatrix trueYt = new DoubleMatrix(1,data[t+1].length,data[t+1]);
//                acts.put("y" + t, trueYt);
//
//                    //System.out.print(indexChar.get(predcitYt.argmax()));
//                error += LossFunction.getLoss(predcitYt, trueYt);
//            }
//
//            System.out.println();

                // bptt
            gru.bptt(acts, slideT-1, lr);
//            gru.bptt(acts, data.length - 2, lr);

//            num += data.length;
            num = slideT;
//            System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
            System.out.println(error / num);
//            if(error/num<0.01)
//                return;

        }
    }

3、后向传播

public void bptt(Map<String, DoubleMatrix> acts, int lastT, double lr) {
        for (int t = lastT; t > -1; t--) {
            DoubleMatrix py = acts.get("py" + t);
            DoubleMatrix y = acts.get("y" + t);
            DoubleMatrix deltaY = py.sub(y);
            acts.put("dy" + t, deltaY);

            // cell output errors
            DoubleMatrix h = acts.get("h" + t);
            DoubleMatrix z = acts.get("z" + t);
            DoubleMatrix r = acts.get("r" + t);
            DoubleMatrix gh = acts.get("gh" + t);

            DoubleMatrix deltaH = null;
            if (t == lastT) {
                deltaH = Why.mmul(deltaY.transpose()).transpose();
            } else {
                DoubleMatrix lateDh = acts.get("dh" + (t + 1));
                DoubleMatrix lateDgh = acts.get("dgh" + (t + 1));
                DoubleMatrix lateDr = acts.get("dr" + (t + 1));
                DoubleMatrix lateDz = acts.get("dz" + (t + 1));
                DoubleMatrix lateR = acts.get("r" + (t + 1));
                DoubleMatrix lateZ = acts.get("z" + (t + 1));
                deltaH = Why.mmul(deltaY.transpose()).transpose()
                        .add(Whr.mmul(lateDr.transpose()).transpose())
                        .add(Whz.mmul(lateDz.transpose()).transpose())
                        .add(Whh.mmul(lateDgh.mul(lateR).transpose()).transpose())
                        .add(lateDh.mul(DoubleMatrix.ones(1, lateZ.columns).sub(lateZ)));
            }
            acts.put("dh" + t, deltaH);

            // gh
            DoubleMatrix deltaGh = deltaH.mul(z).mul(deriveTanh(gh));
            acts.put("dgh" + t, deltaGh);

            DoubleMatrix preH = null;
            if (t > 0) {
                preH = acts.get("h" + (t - 1));
            } else {
                preH = DoubleMatrix.zeros(1, h.length);
            }

            // reset gates
            DoubleMatrix deltaR = (Whh.mmul(deltaGh.mul(preH).transpose()).transpose()).mul(deriveExp(r));
            acts.put("dr" + t, deltaR);

            // update gates
            DoubleMatrix deltaZ = deltaH.mul(gh.sub(preH)).mul(deriveExp(z));
            acts.put("dz" + t, deltaZ);
        }
        updateParameters(acts, lastT, lr);
    }

4、更新权重矩阵和偏置

    private void updateParameters(Map<String, DoubleMatrix> acts, int lastT, double lr) {
        DoubleMatrix gWxr = new DoubleMatrix(Wxr.rows, Wxr.columns);
        DoubleMatrix gWhr = new DoubleMatrix(Whr.rows, Whr.columns);
        DoubleMatrix gbr = new DoubleMatrix(br.rows, br.columns);

        DoubleMatrix gWxz = new DoubleMatrix(Wxz.rows, Wxz.columns);
        DoubleMatrix gWhz = new DoubleMatrix(Whz.rows, Whz.columns);
        DoubleMatrix gbz = new DoubleMatrix(bz.rows, bz.columns);

        DoubleMatrix gWxh = new DoubleMatrix(Wxh.rows, Wxh.columns);
        DoubleMatrix gWhh = new DoubleMatrix(Whh.rows, Whh.columns);
        DoubleMatrix gbh = new DoubleMatrix(bh.rows, bh.columns);

        DoubleMatrix gWhy = new DoubleMatrix(Why.rows, Why.columns);
        DoubleMatrix gby = new DoubleMatrix(by.rows, by.columns);

        for (int t = 0; t < lastT + 1; t++) {
            DoubleMatrix x = acts.get("x" + t).transpose();
            gWxr = gWxr.add(x.mmul(acts.get("dr" + t)));
            gWxz = gWxz.add(x.mmul(acts.get("dz" + t)));
            gWxh = gWxh.add(x.mmul(acts.get("dgh" + t)));

            if (t > 0) {
                DoubleMatrix preH = acts.get("h" + (t - 1)).transpose();
                gWhr = gWhr.add(preH.mmul(acts.get("dr" + t)));
                gWhz = gWhz.add(preH.mmul(acts.get("dz" + t)));
                gWhh = gWhh.add(acts.get("r" + t).transpose().mul(preH).mmul(acts.get("dgh" + t)));
            }
            gWhy = gWhy.add(acts.get("h" + t).transpose().mmul(acts.get("dy" + t)));

            gbr = gbr.add(acts.get("dr" + t));
            gbz = gbz.add(acts.get("dz" + t));
            gbh = gbh.add(acts.get("dgh" + t));
            gby = gby.add(acts.get("dy" + t));
        }

        Wxr = Wxr.sub(clip(gWxr.div(lastT)).mul(lr));
        Whr = Whr.sub(clip(gWhr.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
        br = br.sub(clip(gbr.div(lastT)).mul(lr));

        Wxz = Wxz.sub(clip(gWxz.div(lastT)).mul(lr));
        Whz = Whz.sub(clip(gWhz.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
        bz = bz.sub(clip(gbz.div(lastT)).mul(lr));

        Wxh = Wxh.sub(clip(gWxh.div(lastT)).mul(lr));
        Whh = Whh.sub(clip(gWhh.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
        bh = bh.sub(clip(gbh.div(lastT)).mul(lr));

        Why = Why.sub(clip(gWhy.div(lastT)).mul(lr));
        by = by.sub(clip(gby.div(lastT)).mul(lr));
    }

还有一些辅助性质的代码,就不放上去了。

总的来说,这次华为的比赛还是学到了点东西的,算是对神经网络入了个门,明白了大致的原理。



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值