集成学习中的Stacking
最近在调试自己的脉冲神经网络模型的时候,无意间读到了一篇文章,作者首先建立了十个脉冲模型
S
n
n
1
,
S
n
n
2
,
.
.
S
n
n
10
Snn_{1},Snn_{2},..Snn_{10}
Snn1,Snn2,..Snn10,每个网络的结构相同但初始权值不同并分别在原数据集随机抽样生成的子训练集上训练收敛,在训练完成后进一步利用了集成学习中Stacking的方法,将10个模型合并起来从而实现对数据集的分类任务。
个人感觉在自己的现有模型中也可以尝试一下集成学习的思想,所以学习一下Stacking并记录下来。
一、集成学习是什么?
我最早了解到集成学习来自于一位朋友的吐槽,他觉得集成学习就是耍赖皮,我感觉说的挺形象的,一个小伙子打不过马老师,两个年轻人搞偷袭就可以。咳咳,Stacking的思想就是“三个臭皮匠顶个诸葛亮”,面对较难的问题多个学习器共同学习,每个学习器都给出一个解决方法,最后进行投票决策。
集成方法是将几种机器学习技术组合成一个预测模型的元算法,以达到减小方差(bagging)、偏差(boosting)或改进预测(stacking)的效果。
集合方法可分为两类:
- 序列集成方法,其中参与训练的基础学习器按照顺序生成(例如
AdaBoost)。序列方法的原理是利用基础学习器之间的依赖关系。通过对之前训练中错误标记的样本赋值较高的权重,可以提高整体的预测效果。 - 并行集成方法,其中参与训练的基础学习器并行生成(例如 Random
Forest)。并行方法的原理是利用基础学习器之间的独立性,通过平均可以显著降低错误。
我们不介绍bagging和boosting方法,感兴趣的自己去查找。Stacking它并不是一个单独的机器学习算法或者模型,而是将多个结合在一起,单个的模型称为“个体学习器”,如果个体学习器都相同,那么可称为“基学习器”。 个体学习器组合在一起形成的模型,常常能够使得泛化性能提高,这对于“弱学习器”的提高尤为明显。在进行Stacking的时候,我们希望我们的基学习器应该是好而不同,“好”就是说,基学习器不能太差,“不同”就是各个学习器尽量有差异。
二、Stacking算法
我们借助上面的流程图来看一下Stacking算法是怎样进行的,我们借助一些实例来辅助理解。
我们就以MNIST手写数据集的分类任务为例:
输入:
数据集
D
D
D:60000张训练+10000张测试图片,其中的
x
x
x对应的是像素值矩阵,
y
y
y对应的是图片的类别(0~9的某类)
初级学习算法
L
\mathcal{L}
L:其实就是我们上文提到的基学习器,比如我们选取了VGG、AlexNet以及ResNet三种网络模型对应
L
1
、
L
2
、
L
3
\mathcal{L}_{1}、\mathcal{L}_{2}、\mathcal{L}_{3}
L1、L2、L3
次级学习器我们先空着,继续算法再回头看它是什么
算法流程:
过程1-3 是训练基学习器,在训练集上对于输出
h
1
,
h
2
h_{1},h_{2}
h1,h2和
h
3
h_{3}
h3可以表示已经训练了一个epoch的VGG、AlexNet以及ResNet。
过程5-9是使用训练出来的个体学习器来得预测的结果,我们用 z 1 , z 2 z_{1},z_{2} z1,z2和 z 3 z_{3} z3来表示预测的结果,size都是 [ 50000 , 1 ] [50000,1] [50000,1],我们把它们stack在一起生成 Z , s i z e Z,size Z,size为 [ 50000 , 3 ] [50000,3] [50000,3]当做次级学习器的训练集 D ′ D' D′。
过程11 是用初级学习器预测的结果训练出次级学习器,对于 D ′ D' D′我们要得到最后训练的模型或者训练结果,次级学习器想必你已经理解了起到什么样的作用,例如我们可以选择决策树来作为次级学习器,亦或者我们可以采取投票机制来作为分类的结果,当预测标签的个数超过基学习器数目的一半时就把预测标签当作实际输出。
如果想要预测一个数据的输出,只需要把这条数据用初级学习器预测,然后将预测后的结果用次级学习器预测便可。
代码
 集成学习Stacking的代码很多,这里有一个实例,可以辅助大家理解学习,博客见下面的链接
:https://blog.csdn.net/weixin_43774727/article/details/105532359
我初步改好了自己Stacking模型的框架,感觉效果肯定会有提升的,但是bug没有调完,另外涉及到论文,模型代码没法贴上去,见谅。
==========================================================
我用了三个基学习器集成模型,效果确实有不小的提升。