Lightweight Alpha Matting Network Using Distillation-Based Channel Pruning
使用基于蒸馏通道裁剪的轻量Alpha抠图网络
https://arxiv.org/pdf/2210.07760v1.pdf
https://github.com/DongGeun-Yoon/DCP
摘要
最近,alpha抠图由于在自拍等移动应用中的有用性而受到了广泛关注。由于移动设备的计算资源有限,一个轻量的alpha抠图模型尤为重要。基于此,我们提出了基于蒸馏的通道剪枝网络。在剪枝过程中,我们去掉了学生网络中对模仿教师网络知识影响较小的通道。然后,裁剪后的轻量学生网络通过相同的蒸馏损失训练。实验表明,该方法优于其它轻量模型。
动机
传统的抠图方法包括Affinity-based方法和Sampling-based方法,目前基于深度学习的方法要优于传统方法。这些大多使用U-net或FCN这种encoder-decoder的结构。通过增加通道数和在基线网络中增加辅助模块提升模型的结果,但是这会增加计算量和内存,从而不利于在移动设备上的应用。将教师网络的特征相似性转移到学生网络上,使学生网络比从零开始训练的基线学生网络取得了更好的性能。为了在通道修剪过程中专注于低层次的精细细节,我们借用了预先训练的优异抠图网络的性能,它很好地保留了精细细节。
在修剪路径中,我们使用了批归一化(BN)层的比例因子的稀疏性,并使用强劲的预训练教师模型应用蒸馏损失,该模型能够精确引导学生网络,在其预测中保留精细的结构细节。在训练路径中,我们用修剪阶段使用的相同的蒸馏损失训练修剪后的轻量级网络。
创新点
(1)针对自然图像抠图问题,提出了一种新的通道修剪方法。
(2)通过在通道裁剪步骤中利用蒸馏损失,成功地找到了一个轻量级的alpha抠图网络。
方法论
(1)通道数裁剪
将浅层特征和深层特征的通道数都减半,得到了表1。可以看出,深层特征的通道数裁剪后,参数更少,而且性能更优。也就是说,浅层的通道数更加重要。
(2)知识蒸馏
L F L_F LF表示相似性函数, θ t ( ⋅ ) θ_t(·) θt(⋅)和 θ s ( ⋅ ) θ_s(·) θs(⋅)分别表示教师模型和学生模型的特征转化函数
(3)基于知识蒸馏的剪枝
使用稀疏损失和蒸馏损失训练一个目标学生模型,通过BN层的尺度因子去除通道数。在网络修剪中,只更新学生网络的参数,固定教师网络的参数。最后的损失包括alpha预测损失,通道稀疏损失和蒸馏损失。
其中
(4)知识蒸馏训练
通过知识蒸馏的方式,通过教师模型对已经裁剪后的学生模型进行训练,损失函数定义如下:
结果
从表格看出,在两个数据集上的效果都挺好。
总结
总的思路就是先通过知识蒸馏剪枝,即图2中的上半部分的网络。再通过知识蒸馏训练,将裁剪后的学生网络跟原始的教师模型训练。其中使用的教师模型的参数一直是固定。总体而言,方法比较简单,写得比较通俗易懂。