无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正.
learning_rate , weight_decay , momentum这三个参数的含义. 并附上demo.
我们会使用一个例子来说明一下:
比如我们有一堆数据,我们只知道这对数据是从一个黑盒中得到的,我们现在要寻找到那个具体的函数f(x),我们定义为目标函数T.
我们现在假定有存在这个函数并且这个函数为:
我们现在要使用这对数据来训练目标函数. 我们可以设想如果存在一个这个函数,必定满足{x,y}所有的关系,也就是说:
那么最理想的情况下 : ,那么我们不妨定义这样一个优化目标函数:
对于这堆数据,我们认为当Loss(W)对于所有的pair{x,y}都满足 Loss(W)趋近于或者等于0时,我们认为我们找到这个理想的目标函数T. 也就是此时 .
以上,我们发现寻找的目标函数的问题,已经成功的转移为求解:
也就是Loss 越小, f(x)越接近我们寻找的目标函数T.
那么说了这么多,这个和我们说的学习率learning_rate有什么关系呢?
既然我们知道了我们当前的f(x)和目标函数的T的误差,那么我们可以将这个误差转移到每一个参数上,也就是变成每一个参数w和目标函数T的参数w_t的误差. 然后我们就以一定的幅度stride来缩小和真实值的距离,我们称这个stride为学习率learning_rate 而且我们就是这么做的.
我们用公式表述就是:
我们的误差(损失)Loss:
我们这一个凸函数. 我们先对这个函数进行各个分量求偏导.
对于w0的偏导数:
那么对于分量w0承担的误差为:
并且这个误差带方向.
那么我们需要使我们当前的w0更加接近目标函数的T的w0_t参数.我们需要做运算:
(梯度下降算法)
来更新wo的值. 同理其他参数w,而这个学习率就是来控制我们每次靠近真实值的幅度,为什么要这么做呢?
因为我们表述的误差只是一种空间表述形式我们可以使用均方差也可以使用绝对值,还可以使用对数,以及交叉熵等等,所以只能大致的反映,并不精确,就想我们问路一样,别人告诉我们直走五分钟,有的人走的快,有的人走的慢,所以如果走的快的话,当再次问路的时候,就会发现走多了,而折回来,这就是我们训练过程中的loss曲线震荡严重的原因之一. 所以学习率要设置在合理的大小.
好了说了这么多,这是学习率. 那么什么是权重衰减weight_decay呢? 有什么作用呢?
我们接着看上面的这个Loss(w),我们发现如果参数过多的话,对于高位的w3,我们对其求偏导:
我们发现w3开始大于1的时候,w3会调节的很快,幅度很大,从而使得特征x3变为异常敏感.从而出现过拟合(overfitting).
这个时候,我们需要约束一下w2,w3等高阶参数的大小,于是我们对Loss增加一个惩罚项,使得Loss的正反方向,不应该只由f(x) -y 决定,而还应该加上一个;于是Loss变成了:
我们继续对Loss求解偏导数:
对wo求偏导:
对w3求偏导:
我们发现当x3值过大时,会改变Loss的导数的方向.而来抑制w2,w3等高阶函数的继续增长. 然而这样抑制并不是很灵活,所以我们在前面加入一个系数,这个系数在数学上称之为拉格朗日乘子系数,也就是我们用到的weight_decay. 这样我们可以通过调节weight_decay系数,来调节w3,w2等高阶的增长程度。加入weight_decay后的公式:
从公式可以看出 ,weight_decay越大,抑制越大,w2,w3等系数越小,weight_decay越小,抑制越小,w2,w3等系数越大
那么冲量momentum又是啥?
我们在使用梯度下降法,来调整w时公式是这样的:
我们每一次都是计算当前的梯度:
这样会发现对于那些梯度比较小的地方,参数w更新的幅度比较小,训练变得漫长,或者收敛慢.有时候遇到非最优的凸点,会出现冲不出去的现象.
而冲量加进来是一种快速效应.借助上一次的势能来和当前的梯度来调节当前的参数w.
公式表达为:
这样可以有效的避免掉入凸点无法冲出来,而且收敛速度也快很多.
附上demo: 使用mxnet编码.
1 // 2 // Created by xijun1 on 2017/12/14. 3 // 4 5 #include <iostream> 6 #include <vector> 7 #include <string> 8 #include <mxnet/mxnet-cpp/MxNetCpp.h> 9 #include <mxnet/mxnet-cpp/op.h> 10 11 namespace mlp{ 12 class MlpNet{ 13 public: 14 static mx_float OutputAccuracy(mx_float* pred, mx_float* target) { 15 int right = 0; 16 for (int i = 0; i < 128; ++i) { 17 float mx_p = pred[i * 10 + 0]; 18 float p_y = 0; 19 for (int j = 0; j < 10; ++j) { 20 if (pred[i * 10 + j] > mx_p) { 21 mx_p = pred[i * 10 + j]; 22 p_y = j; 23 } 24 } 25 if (p_y == target[i]) right++; 26 } 27 return right / 128.0; 28 } 29 30 static void net(){ 31 using mxnet::cpp::Symbol; 32 using mxnet::cpp::NDArray; 33 34 Symbol x = Symbol::Variable("X"); 35 Symbol y = Symbol::Variable("label"); 36 37 std::vector<std::int32_t> shapes({512 , 10}); 38 //定义一个两层的网络. wx + b 39 Symbol weight_0 = Symbol::Variable("weight_0"); 40 Symbol biases_0 = Symbol::Variable("biases_0"); 41 42 Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0 43 ,512); 44 45 Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky); 46 47 Symbol weight_1 = Symbol::Variable("weight_1"); 48 Symbol biases_1 = Symbol::Variable("biases_1"); 49 Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10); 50 Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky); 51 Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y); //目标函数,loss函数 52 mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0); 53 54 //定义输入数据 55 std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;}); 56 std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;}); 57 58 //初始化数据 59 for(int i=0 ; i<128 ; i++){ 60 for(int j=0;j<28 ; j++){ 61 //定义x 62 aptr_x.get()[i*28+j]= i % 10 +0.1f; 63 } 64 65 //定义y 66 aptr_y.get()[i]= i % 10; 67 } 68 std::map<std::string, mxnet::cpp::NDArray> args_map; 69 //导入数据 70 NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false); 71 NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false); 72 //将数据转换到NDArray中 73 arr_x.SyncCopyFromCPU(aptr_x.get(),128*28); 74 arr_x.WaitToRead(); 75 76 arr_y.SyncCopyFromCPU(aptr_y.get(),128); 77 arr_y.WaitToRead(); 78 79 args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ; 80 args_map["label"]=arr_y.Slice(0,128).Copy(ctx); 81 NDArray::WaitAll(); 82 //绑定网络 83 mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map); 84 //选择优化器 85 mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd"); 86 mx_float learning_rate = 0.0001; //学习率 87 mx_float weight_decay = 1e-4; //权重 88 opt->SetParam("momentum", 0.9) 89 ->SetParam("lr", learning_rate) 90 ->SetParam("wd", weight_decay); 91 //定义各个层参数的数组 92 NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false); 93 NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false); 94 NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false); 95 NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false); 96 97 //初始化权重参数 98 arr_w_0 = 0.01f; 99 arr_b_1 = 0.01f; 100 arr_w_1 = 0.01f; 101 arr_b_1 = 0.01f; 102 103 //初始化参数 104 executor->arg_dict()["weight_0"]=arr_w_0; 105 executor->arg_dict()["biases_0"]=arr_b_0; 106 executor->arg_dict()["weight_1"]=arr_w_1; 107 executor->arg_dict()["biases_1"]=arr_b_1; 108 109 mxnet::cpp::NDArray::WaitAll(); 110 //训练 111 std::cout<<" Training "<<std::endl; 112 113 int max_iters = 20000; //最大迭代次数 114 //获取训练网络的参数列表 115 std::vector<std::string> args_name = pred.ListArguments(); 116 for (int iter = 0; iter < max_iters ; ++iter) { 117 executor->Forward(true); 118 executor->Backward(); 119 120 if(iter % 100 == 0){ 121 std::vector<NDArray> & out = executor->outputs; 122 std::shared_ptr<mx_float> tp_x( new mx_float[128*28] , 123 [](mx_float * tp_x){ delete [] tp_x ;}); 124 out[0].SyncCopyToCPU(tp_x.get(),128*10); 125 NDArray::WaitAll(); 126 std::cout<<"epoch "<<iter<<" "<<"Accuracy: "<< OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl; 127 } 128 //args_name. 129 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){ 130 //执行 131 if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label") 132 continue; 133 134 opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]); 135 } 136 NDArray::WaitAll(); 137 138 } 139 140 141 } 142 }; 143 } 144 145 int main(int argc , char * argv[]){ 146 mlp::MlpNet::net(); 147 MXNotifyShutdown(); 148 return EXIT_SUCCESS; 149 }
结果:
Training epoch 0 Accuracy: 0.09375 epoch 100 Accuracy: 0.304688 epoch 200 Accuracy: 0.195312 epoch 300 Accuracy: 0.203125 epoch 400 Accuracy: 0.304688 epoch 500 Accuracy: 0.296875 epoch 600 Accuracy: 0.304688 epoch 700 Accuracy: 0.304688 epoch 800 Accuracy: 0.398438 epoch 900 Accuracy: 0.5 epoch 1000 Accuracy: 0.5 epoch 1100 Accuracy: 0.40625 epoch 1200 Accuracy: 0.5 epoch 1300 Accuracy: 0.398438 epoch 1400 Accuracy: 0.40625 epoch 1500 Accuracy: 0.703125 epoch 1600 Accuracy: 0.609375 epoch 1700 Accuracy: 0.507812 epoch 1800 Accuracy: 0.703125 epoch 1900 Accuracy: 0.703125 epoch 2000 Accuracy: 0.804688 epoch 2100 Accuracy: 0.703125 epoch 2200 Accuracy: 0.804688 epoch 2300 Accuracy: 0.804688 epoch 2400 Accuracy: 0.804688 epoch 2500 Accuracy: 0.90625 epoch 2600 Accuracy: 0.90625 epoch 2700 Accuracy: 0.90625 epoch 2800 Accuracy: 1 epoch 2900 Accuracy: 1