预测部分,我们使用的是GRU,也就是简化版的LSTM。其实根据我们的经验,这次赛事使用神经网络效果并不好,因为神经网络需要大量的数据喂食和训练,而这次比赛提供的训练数据量太少了。以下把我们的代码贴出来,大家一起讨论。代码很大程度上参考了github上的https://github.com/Root-lee/Put_Flavors_SA和https://github.com/lipiji/JRNN,在此非常感谢!
我们认为,作为神经网络训练来说,最主要的是分三部分:第一,数据的预处理。这里的数据预处理指的是,通过一些算法,把一些特定点作为异常点或者噪点抹去(例如一个特别高的点)以及对现有的数据的特征进行深入挖掘,提取出更多有效的特征,从而提高训练的有效性和准确性,以及进行一些后续处理。至少,对于这次比赛来说,这部分非常重要,对预测的准确性有极大的影响。
一、数据预处理的去异常点部分
对于神经网络训练来说,不同的数据通常有很不一样的特征。我们认为,要确定怎么去异常点,首先要通过人眼进行观察,从而有个大致的思路。我们因此把比赛官方给我们测试用的某个训练数据集,自己写了代码,画了图,进行观察。官方的训练数据集是每天的一些虚拟机申请数据,如下:
这里面:1、头尾的序号和具体时间是没用的,只有虚拟机名字和日期是有用数据。而虚拟机名字又有些是超出了我们需要预测的范围的没用数据,要剔除。另外,日期也不一定连续,需要用java的日期判断函数把不连续的日期补上。
//这份文件,把所有需要的数据都提炼出来
class FileUtil//官方提供的那个读取文件的类,我把它变成了静态内部类
{
public static String[] read(final String filePath, final Integer spec)
{
File file = new File(filePath);
// 当文件不存在或者不可读时
if ((!isFileExists(file)) || (!file.canRead()))
{
System.out.println("file [" + filePath + "] is not exist or cannot read!!!");
return null;
}
List<String> lines = new LinkedList<String>();
BufferedReader br = null;
FileReader fb = null;
try
{
fb = new FileReader(file);
br = new BufferedReader(fb);
String str = null;
int index = 0;
while (((spec == null) || index++ < spec) && (str = br.readLine()) != null)
{
lines.add(str);
}
}
catch (IOException e)
{
e.printStackTrace();
}
finally
{
closeQuietly(br);
closeQuietly(fb);
}
return lines.toArray(new String[lines.size()]);
}
public static int write(final String filePath, final String[] contents, final boolean append)
{
File file = new File(filePath);
if (contents == null)
{
System.out.println("file [" + filePath + "] invalid!!!");
return 0;
}
if (isFileExists(file) && (!file.canRead()))
{
return 0;
}
FileWriter fw = null;
BufferedWriter bw = null;
try
{
if (!isFileExists(file))
{
file.createNewFile();
}
fw = new FileWriter(file, append);
bw = new BufferedWriter(fw);
for (String content : contents)
{
if (content == null)
{
continue;
}
bw.write(content);
bw.newLine();
}
}
catch (IOException e)
{
e.printStackTrace();
return 0;
}
finally
{
closeQuietly(bw);
closeQuietly(fw);
}
return 1;
}
private static void closeQuietly(Closeable closeable)
{
try
{
if (closeable != null)
{
closeable.close();
}
}
catch (IOException e)
{
}
}
private static boolean isFileExists(final File file)
{
if (file.exists() && file.isFile())
{
return true;
}
return false;
}
}
public class datedataoutput {
public static void main(String[] args) throws ParseException {
// TODO 自动生成的方法存根
String filepath="F:\\我的亿方云同步\\FangCloudV2\\个人文件\\华为比赛相关\\20180322datedatasumury\\originaldata.txt";
String[] readresult=FileUtil.read(filepath, null);//日期是从前往后的
ArrayList<ArrayList<String>> alldatalistdate=new ArrayList<ArrayList<String>>();
ArrayList<ArrayList<Integer>> alldatalistnum=new ArrayList<ArrayList<Integer>>();
//----------------先把第一天的放进去
String tempday=readresult[0].split("\\s+")[2];
String tempflavor=readresult[0].split("\\s+")[1];
alldatalistdate.add(new ArrayList<String>());
alldatalistnum.add(new ArrayList<Integer>());
alldatalistdate.get(0).add(tempday);
alldatalistdate.get(0).add(getWeek(tempday));
for(int i=0;i<18;i++)
alldatalistnum.get(0).add(0);
int tempflavorindex=Integer.valueOf(tempflavor.substring(6));//当天虚拟机的标号
if(tempflavorindex>=1&&tempflavorindex<=15){//排除异常flavor影响
alldatalistnum.get(0).set(0, judgeCPU("flavor"+tempflavorindex));
alldatalistnum.get(0).set(1, judgeMemory("flavor"+tempflavorindex));
alldatalistnum.get(0).set(tempflavorindex+1, 1);
alldatalistnum.get(0).set(17, 1);
}
//----------------先把第一天的放进去
for(int i=1;i<readresult.length;i++){
String thisday=readresult[i].split("\\s+")[2];
String thisflavor=readresult[i].split("\\s+")[1];
if(thisday.equals(alldatalistdate.get(alldatalistdate.size()-1).get(0))){//如果是同一天
int tempflavori=Integer.valueOf(thisflavor.substring(6));//当天这台虚拟机的标号
if(tempflavori>=1&&tempflavori<=15){
alldatalistnum.get(alldatalistnum.size()-1).set(0, alldatalistnum.get(alldatalistnum.size()-1).get(0)+judgeCPU("flavor"+tempflavori));
alldatalistnum.get(alldatalistnum.size()-1).set(1, alldatalistnum.get(alldatalistnum.size()-1).get(1)+judgeMemory("flavor"+tempflavori));
alldatalistnum.get(alldatalistnum.size()-1).set(tempflavori+1, alldatalistnum.get(alldatalistnum.size()-1).get(tempflavori+1)+1);
alldatalistnum.get(alldatalistnum.size()-1).set(17, alldatalistnum.get(alldatalistnum.size()-1).get(17)+1);
}
}
else{//如果是不同的一天,那就要重新建立了
while(dayDifference(thisday, alldatalistdate.get(alldatalistdate.size()-1).get(0))>1){//相互之间日期相差1以上,要把中间空缺的日期补上
String lastdate=alldatalistdate.get(alldatalistdate.size()-1).get(0);
alldatalistdate.add(new ArrayList<String>());
alldatalistdate.get(alldatalistdate.size()-1).add(nextDay(lastdate));
alldatalistdate.get(alldatalistdate.size()-1).add(getWeek(nextDay(lastdate)));
alldatalistnum.add(new ArrayList<Integer>());
for(int z=0;z<18;z++)
alldatalistnum.get(alldatalistnum.size()-1).add(0);
}
alldatalistdate.add(new ArrayList<String>());
alldatalistnum.add(new ArrayList<Integer>());
alldatalistdate.get(alldatalistdate.size()-1).add(thisday);
alldatalistdate.get(alldatalistdate.size()-1).add(getWeek(thisday));
for(int j=0;j<18;j++)
alldatalistnum.get(alldatalistnum.size()-1).add(0);
int tempflavori=Integer.valueOf(thisflavor.substring(6));
if(tempflavori>=1&&tempflavori<=15){//排除异常flavor影响
alldatalistnum.get(alldatalistnum.size()-1).set(0, judgeCPU("flavor"+tempflavori));
alldatalistnum.get(alldatalistnum.size()-1).set(1, judgeMemory("flavor"+tempflavori));
alldatalistnum.get(alldatalistnum.size()-1).set(tempflavori+1, 1);
alldatalistnum.get(alldatalistnum.size()-1).set(17, 1);
}
}
}
//这个时候,alldatalistdate和alldatalistnum已经装好了
for(int k=0;k<alldatalistdate.size();k++){
System.out.print(alldatalistdate.get(k).get(0)+" "+alldatalistdate.get(k).get(1)+" ");
for(int b=0;b<18;b++)
System.out.print(alldatalistnum.get(k).get(b)+" ");
System.out.println();
}
// for(int k=0;k<readresult.length;k++)
// System.out.println(readresult[k]);
}
public static int judgeCPU(String s) {//返回虚拟机CPU的消耗
if(s.equals("flavor1")||s.equals("flavor2")||s.equals("flavor3"))
return 1;
else if(s.equals("flavor4")||s.equals("flavor5")||s.equals("flavor6"))
return 2;
else if(s.equals("flavor7")||s.equals("flavor8")||s.equals("flavor9"))
return 4;
else if(s.equals("flavor10")||s.equals("flavor11")||s.equals("flavor12"))
return 8;
else if(s.equals("flavor13")||s.equals("flavor14")||s.equals("flavor15"))
return 16;
else
return 0;
}
public static int judgeMemory(String s) {//返回虚拟机Memory的消耗,单位GB
if(s.equals("flavor1"))
return 1;
else if(s.equals("flavor2")||s.equals("flavor4"))
return 2;
else if(s.equals("flavor3")||s.equals("flavor5"))
return 4;
else if(s.equals("flavor6")||s.equals("flavor7")||s.equals("flavor8")||s.equals("flavor10"))
return 8;
else if(s.equals("flavor9")||s.equals("flavor11")||s.equals("flavor13"))
return 16;
else if(s.equals("flavor12")||s.equals("flavor14"))
return 32;
else if(s.equals("flavor15"))
return 64;
else
return 0;
}
//返回星期几,中文
public static String getWeek(String sdate) {
// 再转换为时间
Date date = strToDate(sdate);
Calendar c = Calendar.getInstance();
c.setTime(date);
// int hour=c.get(Calendar.DAY_OF_WEEK);
// hour中存的就是星期几了,其范围 1~7
// 1=星期日 7=星期六,其他类推
return new SimpleDateFormat("EEEE").format(c.getTime());
}
public static Date strToDate(String strDate) {
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd");
ParsePosition pos = new ParsePosition(0);
Date strtodate = formatter.parse(strDate, pos);
return strtodate;
}
public static long dayDifference(String d1,String d2) throws ParseException{//天数大的是1,天数小的是2
Date a1 = new SimpleDateFormat("yyyy-MM-dd").parse(d1);
Date b1 = new SimpleDateFormat("yyyy-MM-dd").parse(d2);
//获取相减后天数
long day = (a1.getTime()-b1.getTime())/(24*60*60*1000);
return day;
}
public static String nextDay(String d){//返回这一天的下一天的string表示
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
Date date = sdf.parse(d, new ParsePosition(0));
Calendar calendar = Calendar.getInstance();
calendar.setTime(date);
// add方法中的第二个参数n中,正数表示该日期后n天,负数表示该日期的前n天
calendar.add(Calendar.DATE, 1);
Date date1 = calendar.getTime();
String out = sdf.format(date1);
return out;
}
}
最后我们把文件在excel上进行处理,并且用origin画柱形图(不能画散点图,没法看清趋势),得到了以下的数据:
可以发现,这些数据有以下几个特点:
1、对于单个Flavor来说,非常稀疏;
2、时间间隔很大;
3、有一些明显的天数里,虚拟机申请量极高,高到不同寻常。
因此,必须找到一些数据处理办法,把这些异常点去掉。我们尝试过的方法有:
1、如果某天申请量超过了某个阈值,那就把那天的申请量变为那天的前一天和后一天的平均值。如果那天刚好在第一天(最后一天),则变为后一天(前一天)的两倍减去后第二天(前第二天)。这样做我们认为可以去去掉异常点的同时,让改天反映出正常的变化趋势。我们最终使用的是这个方法,而且在所有尝试过的方法中效果是最好的(阈值设为22)
public static void dealwithlargedatabeyondconstantreplacefrontbackavgnew(double[][] input,double expect) {
int col=input[0].length;
for(int i=0;i<row;i++)
for(int j=0;j<col;j++)
if(input[i][j]>expect)
if(i==0)
input[i][j]=2*input[i+1][j]-input[i+2][j];
else if(i==row-1)
input[i][j]=2*input[i-1][j]-input[i-2][j];
else
input[i][j]=(input[i-1][j]+input[i+1][j])/2;
2、正态分布法
在正态分布的假设下,可以算出全部序列值的均值和方差,如果某个值分布超过了标准差的三倍,则认为是异常点。
3、截头去尾法
把一系列时间序列的数据排序,把头或尾的某个百分比的数据都认为是异常点。我询问过一位在罗切斯特大学学计算机的同学,他们经常使用这个方法作为数据预处理。
二、数据特征挖掘部分
作为原始的数据,数据特征的挖掘很重要,因为原始数据的特征太少了,只有一个时间和个数信息。这个时候的疑问是,申请是否跟星期几有关?各个虚拟机申请直接是否有关系?CPU和MEM总数有没有影响?尽管我们都尝试过把他们作为数据的特征之一加进去,但是效果仍然不尽理想,最后我们提交的代码,并没有用上除了时间和个数之外的额外信息。例如,如果看当天的CPU和MEM总数,可以发现:
似乎隐隐有一个周期性的变化,大概以一个星期为周期。
1、加入星期几
我们尝试过把星期几的信息变为一个7维向量加入到每天的特征中。
2、是否周末
我们尝试过把是否周末作为特征加进去。
3、当天的CPU或MEM总量
我们尝试过把当天CPU和MEM总量加进去。
上面这些方法,效果都比不上什么都不加好。
三、数据进一步处理部分
为什么要归一化呢?在机器学习中领域中的数据分析之前,通常需要将数据标准化,利用标准化后得数据进行数据分析。不同评价指标往往具有不同的量纲和量纲单位,这样的情况会影响到数据分析的结果,为了消除指标之间的量纲影响,需要进行数据标准化处理,以解决数据指标之间的可比性。原始数据经过数据标准化处理后,各指标处于同一数量级,适合进行综合对比评价。我们用的是最常用的z-score归一化方法。当然还有一个最大最小标准化,效果不如z-score好。
四、预测部分
预测部分就很考验调参了。稍微一点变化都会导致分数有巨大的变动。为了保证结果可重复性,有随机数的地方我们都采用了随机数种子。
1、GRU的权重矩阵。在这里,我们分别用两种不同的方法初始化。一种是均匀随机数,一种是高斯随机数。其实我们还试过高斯随机截头去尾法,也就是对每个生成的高斯随机数,如果原理均值太远,则重新生成。最后使用的是0-1之间的均匀随机数。如果使用高斯随机数,只要稍微对方差做出一点变化,对整体分数变化都很大,各个参数都要重新调整。
private int inSize;
private int outSize;
private int deSize;
private DoubleMatrix Wxr;
private DoubleMatrix Whr;
private DoubleMatrix br;
private DoubleMatrix Wxz;
private DoubleMatrix Whz;
private DoubleMatrix bz;
private DoubleMatrix Wxh;
private DoubleMatrix Whh;
private DoubleMatrix bh;
private DoubleMatrix Why;
private DoubleMatrix by;
public GRU(int inSize, int outSize, MatIniter initer) {
this.inSize = inSize;
this.outSize = outSize;
if (initer.getType() == Type.Uniform) {
this.Wxr = initer.uniform(inSize, outSize);
this.Whr = initer.uniform(outSize, outSize);
this.br = new DoubleMatrix(1, outSize);
this.Wxz = initer.uniform(inSize, outSize);
this.Whz = initer.uniform(outSize, outSize);
this.bz = new DoubleMatrix(1, outSize);
this.Wxh = initer.uniform(inSize, outSize);
this.Whh = initer.uniform(outSize, outSize);
this.bh = new DoubleMatrix(1, outSize);
this.Why = initer.uniform(outSize, inSize);
this.by = new DoubleMatrix(1, inSize);
} else if (initer.getType() == Type.Gaussian) {
this.Wxr = initer.gaussian(inSize, outSize);
this.Whr = initer.gaussian(outSize, outSize);
this.br = new DoubleMatrix(1, outSize);
this.Wxz = initer.gaussian(inSize, outSize);
this.Whz = initer.gaussian(outSize, outSize);
this.bz = new DoubleMatrix(1, outSize);
this.Wxh = initer.gaussian(inSize, outSize);
this.Whh = initer.gaussian(outSize, outSize);
this.bh = new DoubleMatrix(1, outSize);
this.Why = initer.gaussian(outSize, inSize);
this.by = new DoubleMatrix(1, inSize);
}
}
2、训练传递部分。训练部分,就是考验调参的时刻了。学习速率,每次训练的时间窗口,训练次数,都对分数有巨大影响。而且因为计算时间和计算能力的限制,只能采用随机梯度下降算法来进行训练,这就导致有可能在训练损失函数上有跳变,无法完全收敛。这里,其实我们还尝试过一些自适应学习率下降的方法,例如动量法等,但是其实效果都不明显,最后还是就使用了固定的学习速率。还有一些更高深的训练方法,也没有尝试过。
public void train(double[][] data, double lr) {
// java.util.Random random = new java.util.Random();
java.util.Random random = new java.util.Random(5);
int slideT = GRUPredict.slideT;
for (int i = 0; i < 2500; i++) {
/**
* 学习率处理-自动变小
* ***/
// if (i-i/2000*2000==0)
// lr = lr / 3;
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
acts = new HashMap<>();
int index = random.nextInt(data.length - slideT);
for (int t = 0; t < slideT; t++) {
double[] dayData = data[index+t];
DoubleMatrix xt = new DoubleMatrix(1,dayData.length,dayData);
acts.put("x" + t, xt);
gru.active(t, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
acts.put("py" + t, predcitYt);
DoubleMatrix trueYt = new DoubleMatrix(1,data[index+1+t].length,data[index+1+t]);
acts.put("y" + t, trueYt);
//System.out.print(indexChar.get(predcitYt.argmax()));
error += LossFunction.getLoss(predcitYt, trueYt);
}
// for (int t = 0; t < data.length-1; t++) {打印每次loss function观察收敛情况
// double[] dayData = data[t];
// DoubleMatrix xt = new DoubleMatrix(1,dayData.length,dayData);
// acts.put("x" + t, xt);
//
// gru.active(t, acts);
//
// DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
// acts.put("py" + t, predcitYt);
// DoubleMatrix trueYt = new DoubleMatrix(1,data[t+1].length,data[t+1]);
// acts.put("y" + t, trueYt);
//
// //System.out.print(indexChar.get(predcitYt.argmax()));
// error += LossFunction.getLoss(predcitYt, trueYt);
// }
//
// System.out.println();
// bptt
gru.bptt(acts, slideT-1, lr);
// gru.bptt(acts, data.length - 2, lr);
// num += data.length;
num = slideT;
// System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
System.out.println(error / num);
// if(error/num<0.01)
// return;
}
}
3、后向传播
public void bptt(Map<String, DoubleMatrix> acts, int lastT, double lr) {
for (int t = lastT; t > -1; t--) {
DoubleMatrix py = acts.get("py" + t);
DoubleMatrix y = acts.get("y" + t);
DoubleMatrix deltaY = py.sub(y);
acts.put("dy" + t, deltaY);
// cell output errors
DoubleMatrix h = acts.get("h" + t);
DoubleMatrix z = acts.get("z" + t);
DoubleMatrix r = acts.get("r" + t);
DoubleMatrix gh = acts.get("gh" + t);
DoubleMatrix deltaH = null;
if (t == lastT) {
deltaH = Why.mmul(deltaY.transpose()).transpose();
} else {
DoubleMatrix lateDh = acts.get("dh" + (t + 1));
DoubleMatrix lateDgh = acts.get("dgh" + (t + 1));
DoubleMatrix lateDr = acts.get("dr" + (t + 1));
DoubleMatrix lateDz = acts.get("dz" + (t + 1));
DoubleMatrix lateR = acts.get("r" + (t + 1));
DoubleMatrix lateZ = acts.get("z" + (t + 1));
deltaH = Why.mmul(deltaY.transpose()).transpose()
.add(Whr.mmul(lateDr.transpose()).transpose())
.add(Whz.mmul(lateDz.transpose()).transpose())
.add(Whh.mmul(lateDgh.mul(lateR).transpose()).transpose())
.add(lateDh.mul(DoubleMatrix.ones(1, lateZ.columns).sub(lateZ)));
}
acts.put("dh" + t, deltaH);
// gh
DoubleMatrix deltaGh = deltaH.mul(z).mul(deriveTanh(gh));
acts.put("dgh" + t, deltaGh);
DoubleMatrix preH = null;
if (t > 0) {
preH = acts.get("h" + (t - 1));
} else {
preH = DoubleMatrix.zeros(1, h.length);
}
// reset gates
DoubleMatrix deltaR = (Whh.mmul(deltaGh.mul(preH).transpose()).transpose()).mul(deriveExp(r));
acts.put("dr" + t, deltaR);
// update gates
DoubleMatrix deltaZ = deltaH.mul(gh.sub(preH)).mul(deriveExp(z));
acts.put("dz" + t, deltaZ);
}
updateParameters(acts, lastT, lr);
}
4、更新权重矩阵和偏置
private void updateParameters(Map<String, DoubleMatrix> acts, int lastT, double lr) {
DoubleMatrix gWxr = new DoubleMatrix(Wxr.rows, Wxr.columns);
DoubleMatrix gWhr = new DoubleMatrix(Whr.rows, Whr.columns);
DoubleMatrix gbr = new DoubleMatrix(br.rows, br.columns);
DoubleMatrix gWxz = new DoubleMatrix(Wxz.rows, Wxz.columns);
DoubleMatrix gWhz = new DoubleMatrix(Whz.rows, Whz.columns);
DoubleMatrix gbz = new DoubleMatrix(bz.rows, bz.columns);
DoubleMatrix gWxh = new DoubleMatrix(Wxh.rows, Wxh.columns);
DoubleMatrix gWhh = new DoubleMatrix(Whh.rows, Whh.columns);
DoubleMatrix gbh = new DoubleMatrix(bh.rows, bh.columns);
DoubleMatrix gWhy = new DoubleMatrix(Why.rows, Why.columns);
DoubleMatrix gby = new DoubleMatrix(by.rows, by.columns);
for (int t = 0; t < lastT + 1; t++) {
DoubleMatrix x = acts.get("x" + t).transpose();
gWxr = gWxr.add(x.mmul(acts.get("dr" + t)));
gWxz = gWxz.add(x.mmul(acts.get("dz" + t)));
gWxh = gWxh.add(x.mmul(acts.get("dgh" + t)));
if (t > 0) {
DoubleMatrix preH = acts.get("h" + (t - 1)).transpose();
gWhr = gWhr.add(preH.mmul(acts.get("dr" + t)));
gWhz = gWhz.add(preH.mmul(acts.get("dz" + t)));
gWhh = gWhh.add(acts.get("r" + t).transpose().mul(preH).mmul(acts.get("dgh" + t)));
}
gWhy = gWhy.add(acts.get("h" + t).transpose().mmul(acts.get("dy" + t)));
gbr = gbr.add(acts.get("dr" + t));
gbz = gbz.add(acts.get("dz" + t));
gbh = gbh.add(acts.get("dgh" + t));
gby = gby.add(acts.get("dy" + t));
}
Wxr = Wxr.sub(clip(gWxr.div(lastT)).mul(lr));
Whr = Whr.sub(clip(gWhr.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
br = br.sub(clip(gbr.div(lastT)).mul(lr));
Wxz = Wxz.sub(clip(gWxz.div(lastT)).mul(lr));
Whz = Whz.sub(clip(gWhz.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
bz = bz.sub(clip(gbz.div(lastT)).mul(lr));
Wxh = Wxh.sub(clip(gWxh.div(lastT)).mul(lr));
Whh = Whh.sub(clip(gWhh.div(lastT < 2 ? 1 : (lastT - 1))).mul(lr));
bh = bh.sub(clip(gbh.div(lastT)).mul(lr));
Why = Why.sub(clip(gWhy.div(lastT)).mul(lr));
by = by.sub(clip(gby.div(lastT)).mul(lr));
}
还有一些辅助性质的代码,就不放上去了。
总的来说,这次华为的比赛还是学到了点东西的,算是对神经网络入了个门,明白了大致的原理。