Truncated Gradient --截断梯度
简介
最近接触了大规模机器学习,在进行模型训练的时候采用的是广义线性模型,由于超高维度(十亿级别)导致训练的模型最后超级大,为了上线模型服务,最后的模型不能太大,需要进行模型的剪枝,于是就涉及到了梯度截断,用以减少模型的最终的权重的数量。同时梯度截断也可以减少不重要特征,凸显重要的特征在模型的影响,此外稀疏化的模型在参数更新过程中更具优势。
简单截断法
简单粗暴的方法,设置一个固定的阈值,当某个w小于阈值的时候,直接赋值为0。
这里有一个窗口的概念,参数k为窗口,表示采用截断的最小步长,也就是说截断不是每次都会触发。训练过程中每个batch用参数i表示,(每个可能包含多个样本,也可能包含一个,在线学习时候batchsize一般为1,但是受限于性能,实际应用中batchsize一般大于1)。
当i/k不是整数时候,不触发截断,梯度更新方式和如下:
W ( t + 1 ) = W ( t ) − η ( t ) G ( t ) W _ { ( t + 1 ) } = W _ { ( t ) } - \eta ^ { ( t ) } G ^ { ( t ) } W(t+1)=W(t)−η(t)G(t)
其中, G ( t ) G ^ { ( t ) } G(t)为第t次更新中损失函数的梯度, η ( t ) \eta ^ { ( t ) } η(t)为学习率。
当i/k为整数时候,梯度更新方式和如下:
W ( t + 1 ) = T 0 ( W ( t ) − η ( t ) G ( t ) , θ ) W ^ { ( t + 1 ) } = T _ { 0 } \left( W ^ { ( t ) } - \eta ^ { ( t ) } G ^ { ( t ) } , \theta \right) W(t+1)=T0(W(t)−η(t)G