经典网络—DeiT:高效的数据蒸馏视觉Transformer
1. 引言
近年来,Transformer 在计算机视觉(CV)领域取得了显著突破。ViT(Vision Transformer)虽然具有出色的性能,但对大规模数据依赖较强,导致训练成本较高。Facebook AI 提出的 DeiT(Data-efficient Image Transformer) 通过 数据蒸馏(Distillation) 方法,显著降低了对大规模数据的依赖,使得 Transformer 在 CV 任务中更高效。
2. DeiT 简介
DeiT 由 Facebook AI 研究团队在论文"Training data-efficient image transformers & distillation through attention"中提出。它主要解决 ViT 需要大量数据预训练 的问题,关键方法包括:
- 数据蒸馏(Distillation through Attention)
- 蒸馏 Token(Distillation Token)
- 高效的训练策略
3. DeiT 关键技术
3.1 数据蒸馏(Distillation through Attention)
传统知识蒸馏(Knowledge Distillation)通常用于 CNN 模型,其中一个训练好的教师网络(Teacher)指导较小的学生网络(Student)。但在 DeiT 中,蒸馏方法被 无缝集成到 Transformer 训练过程 中,并引入了 蒸馏 Token(Distillation Token)。
3.2 蒸馏 Token(Distillation Token)
在 ViT 结构中,输入图像被分割成多个 Patch,并添加 CLS Token 进行分类。而在 DeiT 中,
- 额外引入了 Distillation Token
- 它与 CLS Token 共同参与注意力机制
- 通过 教师网络(通常是 CNN) 进行监督
这种方法 不增加计算量,却能提升模型的样本利用效率。
3.3 DeiT 训练策略
- 优化器:AdamW
- 数据增强:RandAugment、Mixup、CutMix
- 正则化:Stochastic Depth、Label Smoothing
- 学习率策略:Cosine Decay
这些优化方法,使 DeiT 不依赖大规模数据 也能取得良好效果。
4. DeiT 代码示例
可以使用 timm
库快速调用预训练模型:
import torch
import timm
# 加载 DeiT 模型(DeiT-small 预训练)
model = timm.create_model('deit_small_distilled_patch16_224', pretrained=True)
model.eval()
# 测试一张随机图片
x = torch.randn(1, 3, 224, 224)
y = model(x)
print(y.shape) # 输出类别预测
如果想要 微调 DeiT,可以使用 PyTorch 训练自己的数据集。
5. DeiT 与 ViT 对比
特性 | ViT | DeiT |
---|---|---|
依赖数据 | 需要大规模数据 | 可在小数据集上训练 |
蒸馏机制 | 无 | 有(Distillation Token) |
计算开销 | 高 | 较低 |
训练策略 | 标准训练 | 数据增强 + 知识蒸馏 |
6. 结论
DeiT 通过 数据蒸馏、蒸馏 Token 和优化训练策略,成功减少了 ViT 对大规模数据的依赖,使 Transformer 在计算机视觉任务中更高效。对于资源受限的研究者或企业,DeiT 提供了 高效的 Transformer 解决方案,可用于分类、目标检测等任务。
你对 DeiT 有什么看法?欢迎在评论区交流! 🎯