DeiT:数据高效的图像Transformer及其工作原理详解

DeiT:数据高效的图像Transformer及其工作原理详解

随着Transformer架构在自然语言处理(NLP)领域的巨大成功,研究者们开始探索其在计算机视觉领域的应用。Vision Transformer(ViT)是最早将Transformer直接应用于图像分类的模型之一,但其训练需要依赖大规模数据集(如JFT-300M)和强大的计算资源,这限制了其广泛应用。针对这一问题,Facebook AI和Sorbonne University的研究团队提出了DeiT(Data-efficient image Transformers),一种仅使用ImageNet数据集(约130万张图像)即可高效训练的图像Transformer模型。本文将详细介绍DeiT的原理,特别针对熟悉Transformer结构的深度学习研究者,深入探讨其架构设计、训练策略以及创新的蒸馏方法。

下文中图片来自于原论文:https://arxiv.org/pdf/2012.12877


一、DeiT的核心思想与背景

DeiT的目标是解决ViT的一个关键问题:Transformer在视觉任务中对数据量的依赖性。ViT的研究表明,如果仅使用ImageNet这样的中小规模数据集,Transformer模型的性能会显著低于卷积神经网络(CNN)。DeiT通过优化训练策略和引入特定的知识蒸馏方法,成功地在单台8-GPU机器上(训练时间2-3天)实现了与CNN竞争的性能,其参考模型(DeiT-B,86M参数)在ImageNet上达到了83.1%的top-1准确率(单裁剪),甚至在蒸馏后最高可达85.2%。

DeiT的核心贡献包括:

  1. 数据高效训练:通过强数据增强和正则化策略,使Transformer在ImageNet-only的场景下也能表现出色。
  2. 新型蒸馏方法:提出了一种专为Transformer设计的“蒸馏token”策略,利用教师模型(可以是CNN或Transformer)的知识进一步提升学生模型性能。
  3. 迁移能力:在下游任务(如CIFAR、iNaturalist等)上表现出与CNN相当的泛化能力。

接下来,我们将从架构、训练和蒸馏三个方面详细剖析DeiT的原理。


二、DeiT的架构设计

DeiT的架构直接继承自ViT,因此熟悉Transformer的研究者可以快速理解其结构。以下是其关键组件的详细说明:

1. 输入处理:图像分块与嵌入

与ViT相同,DeiT将输入图像(固定分辨率,如224×224)分割成固定大小的patch(通常为16×16像素)。对于224×224的图像,这会生成 ( 14 × 14 = 196 14 \times 14 = 196 14×14=196 ) 个patch。每个patch被展平并通过线性层投影到一个固定维度(例如DeiT-B中为768维),从而形成patch嵌入序列:

  • 输入图像 ( X ∈ R H × W × 3 X \in \mathbb{R}^{H \times W \times 3} XRH×W×3 ) → ( N N N ) 个patch(( N = H P × W P N = \frac{H}{P} \times \frac{W}{P} N=PH×PW )),其中 ( P = 16 P=16 P=16 )。
  • 线性投影:( X p ∈ R N × ( P 2 ⋅ 3 ) → Z 0 ∈ R N × D X_p \in \mathbb{R}^{N \times (P^2 \cdot 3)} \rightarrow Z_0 \in \mathbb{R}^{N \times D} XpRN×(P23)Z0RN×D )。
2. 位置编码(Positional Encoding)

由于Transformer缺乏CNN的空间归纳偏置(inductive bias),需要显式加入位置信息。DeiT沿用ViT的做法,在patch嵌入后添加位置编码(可以是固定的正弦编码或可学习的参数),以保留patch的空间关系:

  • ( Z 0 = [ z p a t c h 1 , z p a t c h 2 , . . . , z p a t c h N ] + E p o s Z_0 = [z_{patch_1}, z_{patch_2}, ..., z_{patch_N}] + E_{pos} Z0=[zpatch1,zpatch2,...,zpatchN]+Epos ),其中 ( E p o s ∈ R N × D E_{pos} \in \mathbb{R}^{N \times D} EposRN×D )。

一个改进点是,DeiT支持在不同分辨率下微调模型(例如从224×224训练到384×384微调)。此时,patch数量 ( N N N ) 会变化,DeiT通过插值(通常为双三次插值)调整位置编码的大小,确保模型适配性。

3. 类token(Class Token)

为了进行分类,DeiT在patch序列前添加一个可学习的类token(class token),其作用类似于NLP中BERT的[CLS] token。类token与patch token一起通过Transformer层处理,最终在最后一层通过线性分类器预测类别:

  • 输入序列:( Z 0 = [ z c l a s s , z p a t c h 1 , . . . , z p a t c h N ] Z_0 = [z_{class}, z_{patch_1}, ..., z_{patch_N}] Z0=[zclass,zpatch1,...,zpatchN] )。
4. Transformer块

DeiT的Transformer块与标准结构一致,每个块包括:

  • 多头自注意力(Multi-head Self-Attention, MSA)
    • ( Attention ( Q , K , V ) = Softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d}})V Attention(Q,K,V)=Softmax(d QKT)V ),其中 ( Q , K , V Q, K, V Q,K,V ) 由输入序列线性变换生成。
    • 多头机制通过 ( h h h ) 个并行注意力头增强表达能力(DeiT-B中 ( h = 12 h=12 h=12 ))。
  • 前馈网络(Feed-Forward Network, FFN)
    • 两层MLP,中间使用GeLU激活,第一层将维度扩展到 ( 4D ),第二层还原到 ( D D D )。
  • 残差连接与层归一化(LayerNorm)
    • ( Z ′ = MSA ( LayerNorm ( Z ) ) + Z Z' = \text{MSA}(\text{LayerNorm}(Z)) + Z Z=MSA(LayerNorm(Z))+Z )。
    • ( Z = FFN ( LayerNorm ( Z ′ ) ) + Z ′ Z = \text{FFN}(\text{LayerNorm}(Z')) + Z' Z=FFN(LayerNorm(Z))+Z )。

DeiT-B由12个Transformer块组成,嵌入维度 ( D = 768 D=768 D=768 ),每头维度 ( d = D / h = 64 d = D/h = 64 d=D/h=64 )。

5. 输出层

最后一层的类token经过线性层投影到类别数(如ImageNet的1000类),输出logits用于分类。


三、数据高效训练策略

由于Transformer缺乏CNN的局部性偏置,其训练需要更多数据或更强的正则化。DeiT通过以下策略实现了数据高效性:

1. 强数据增强

DeiT大量借鉴CNN的增强技术,包括:

  • Rand-Augment:随机选择增强操作(如旋转、剪切等),参数为9/0.5。
  • Mixup(概率0.8):混合两张图像及其标签。
  • CutMix(概率1.0):将一张图像的部分替换为另一张图像。
  • 随机擦除(Random Erasing)(概率0.25):随机遮挡图像区域。

这些增强显著增加了数据的多样性,帮助Transformer在有限数据下学习鲁棒特征。

2. 正则化与优化
  • Stochastic Depth(概率0.1):随机丢弃Transformer块,增强深层网络的训练稳定性。
  • Label Smoothing(( ε = 0.1 \varepsilon=0.1 ε=0.1)):平滑标签分布,减少过拟合。
  • 优化器:使用AdamW(学习率 ( 5 × 1 0 − 4 × batchsize 512 5 \times 10^{-4} \times \frac{\text{batchsize}}{512} 5×104×512batchsize )),权重衰减0.05,配合余弦学习率衰减和5个epoch的warmup。
3. 重复增强(Repeated Augmentation)

DeiT采用重复增强策略,即对同一张图像多次应用不同的增强变换(通常3次),增加训练时的样本多样性。这一策略显著提升了性能,尤其在300 epoch的训练中效果明显。

4. 分辨率调整

DeiT首先在224×224分辨率下预训练(约53小时,8-GPU),然后在更高分辨率(如384×384)微调(约20小时)。微调时通过插值调整位置编码,保持模型一致性。


四、创新的蒸馏方法:Distillation Token

DeiT的一个亮点是提出了专为Transformer设计的蒸馏策略,通过引入蒸馏token(distillation token)增强学生模型的学习。以下是其原理和实现细节:

1. 传统知识蒸馏回顾

传统知识蒸馏(Knowledge Distillation, KD, 具体可以参考笔者的另一篇博客:Hinton提出的知识蒸馏(Knowledge Distillation,简称KD):原理解释和代码实现)通过教师模型的软标签(soft labels)指导学生模型:

  • 软蒸馏:学生模型最小化其softmax输出与教师softmax输出之间的KL散度:
    • ( L global = ( 1 − λ ) L CE ( ψ ( Z s ) , y ) + λ τ 2 KL ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) \mathcal{L}_{\text{global}} = (1-\lambda) \mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \lambda \tau^2 \text{KL}(\psi(Z_s/\tau), \psi(Z_t/\tau)) Lglobal=(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ)) ),
    • 其中 ( Z s , Z t Z_s, Z_t Zs,Zt ) 为学生和教师的logits,( τ \tau τ) 为温度,( λ \lambda λ) 为平衡因子。

然而,DeiT发现软蒸馏对Transformer的效果不如预期,因此提出了两种改进:

  • 硬蒸馏:直接使用教师的硬标签(( y t = argmax ( Z t ) y_t = \text{argmax}(Z_t) yt=argmax(Zt) ))作为监督信号:
    • ( L global hardDistill = 1 2 L CE ( ψ ( Z s ) , y ) + 1 2 L CE ( ψ ( Z s ) , y t ) \mathcal{L}_{\text{global}}^{\text{hardDistill}} = \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(Z_s), y_t) LglobalhardDistill=21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt) )。

硬蒸馏在Transformer上表现更好,因为它简单且无需调参,同时能适应数据增强带来的标签变化。

2. 蒸馏Token的设计

DeiT进一步提出了一种Transformer特有的蒸馏方法:

  • 在输入序列中添加一个额外的蒸馏token,与类token并存:
    • 输入序列变为:( Z 0 = [ z c l a s s , z d i s t i l l , z p a t c h 1 , . . . , z p a t c h N ] Z_0 = [z_{class}, z_{distill}, z_{patch_1}, ..., z_{patch_N}] Z0=[zclass,zdistill,zpatch1,...,zpatchN] )。
  • 蒸馏token通过自注意力机制与patch token和类token交互,目标是重现教师模型的预测(硬标签 ( y t y_t yt ))。
  • 在最后一层,蒸馏token通过独立的线性分类器输出预测,与类token的输出互补。

在这里插入图片描述

3. 训练与推理
  • 训练时:损失函数结合真标签(作用于类token)和教师标签(作用于蒸馏token),两者的权重相等。
  • 推理时:可以单独使用类token或蒸馏token的分类器,也可以融合两者(late fusion,softmax输出相加),融合方式通常效果最佳。
4. 为什么有效?
  • 互补性:实验表明,类token和蒸馏token在训练后收敛到不同的向量(初始余弦相似度0.06,最后层0.93),表明它们捕获了不同信息。
  • 教师偏置:当使用CNN(如RegNetY-16GF)作为教师时,蒸馏token能引入卷积的局部性偏置,使Transformer受益于CNN的归纳能力。
  • 性能提升:相比传统硬蒸馏,DeiT的蒸馏token方法将准确率从83.0%提升至84.5%(DeiT-B,224分辨率)。

五、实验结果与分析
1. ImageNet性能
  • DeiT-B(无蒸馏):83.1% top-1(384分辨率)。
  • DeiT-B蒸馏(DeiT-B(\pi)):85.2% top-1,超越EfficientNet和JFT-300M预训练的ViT-B。
  • 小模型:DeiT-S(22M参数)和DeiT-Ti(5M参数)分别达到81.2%和74.5%,适合资源受限场景。
2. 与CNN的对比

DeiT在吞吐量(images/sec)与准确率的权衡上接近EfficientNet,尤其在蒸馏后甚至超越,显示出Transformer的潜力。

3. 迁移学习

在CIFAR-10、Flowers-102等任务上,DeiT的top-1准确率(如99.1%、98.9%)与CNN相当,证明其泛化能力。

4. 教师选择的影响

使用CNN(如RegNetY-16GF,82.9%准确率)作为教师比Transformer更有效,可能是因为CNN的偏置对Transformer的训练更有指导意义。


六、总结与展望

DeiT通过优化训练策略和创新的蒸馏token方法,成功地将Transformer引入数据受限的视觉任务中,其性能已接近甚至超越经过多年优化的CNN。未来研究可以探索:

  • 针对Transformer的专用数据增强方法。
  • 更高效的架构设计,进一步降低计算复杂度。
  • 在更大规模任务(如检测、分割)中的应用。

对于熟悉Transformer的研究者来说,DeiT提供了一个高效的起点,其开源代码(https://github.com/facebookresearch/deit)也便于复现和扩展实验。DeiT不仅是视觉Transformer的一个里程碑,也预示着Transformer可能成为计算机视觉的主流范式之一。

DeiT(Data-efficient image Transformers)的示例代码

以下是基于PyTorch实现的DeiT(Data-efficient image Transformers)的示例代码,包括训练代码和推理代码。由于DeiT的完整实现涉及较多细节(例如数据增强、蒸馏策略等),将提供一个简化的版本,重点展示其核心结构和逻辑。完整的实现可以参考官方代码库(https://github.com/facebookresearch/deit)。

前提条件

  • PyTorch 1.7+
  • torchvision
  • timm(可选,用于预训练模型和增强)
  • 数据集:这里以ImageNet为例,假设你已准备好数据加载器。

1. DeiT模型定义

首先定义DeiT的核心模型结构,基于ViT并添加蒸馏token。

import torch
import torch.nn as nn
import torch.nn.functional as F

class DeiT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, 
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., drop_rate=0.1):
        super().__init__()
        
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
        # Positional encoding
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 2, embed_dim))  # +2 for cls and distill tokens
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        # Class and distillation tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.distill_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, 
                                      dim_feedforward=int(embed_dim * mlp_ratio), dropout=drop_rate)
            for _ in range(depth)
        ])
        
        # Layer norm
        self.norm = nn.LayerNorm(embed_dim)
        
        # Classification heads
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_distill = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.distill_token, std=0.02)

    def forward(self, x, return_both=False):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        
        # Add class and distillation tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        distill_tokens = self.distill_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, distill_tokens, x), dim=1)  # [B, num_patches + 2, embed_dim]
        
        # Add positional encoding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.norm(x)
        
        # Extract class and distillation tokens
        cls_output = x[:, 0]  # [B, embed_dim]
        distill_output = x[:, 1]  # [B, embed_dim]
        
        # Classification
        cls_logits = self.head(cls_output)
        distill_logits = self.head_distill(distill_output)
        
        if return_both:
            return cls_logits, distill_logits
        return cls_logits  # 默认返回cls token的输出

# 示例模型实例化
model = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)

2. 训练代码

以下是训练DeiT的示例代码,包含蒸馏逻辑和常见数据增强。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm  # 用于加载教师模型

# 数据增强和加载器
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 假设ImageNet数据集路径为 './data/imagenet'
train_dataset = datasets.ImageFolder('./data/imagenet/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('./data/imagenet/val', transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 学生模型
student = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)
student = student.to(device)

# 教师模型(预训练的CNN,例如RegNetY-16GF)
teacher = timm.create_model('regnety_160', pretrained=True, num_classes=1000)
teacher = teacher.to(device)
teacher.eval()  # 教师模型固定

# 损失函数和优化器
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(student.parameters(), lr=5e-4, weight_decay=0.05)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)  # 300 epochs

# 训练循环
def train_epoch(student, teacher, loader, optimizer, criterion, epoch):
    student.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # 前向传播
        cls_logits, distill_logits = student(images, return_both=True)
        with torch.no_grad():
            teacher_logits = teacher(images)
            teacher_labels = teacher_logits.argmax(dim=1)  # 硬标签
        
        # 损失:真标签(cls token)+ 教师硬标签(distill token)
        loss_cls = criterion(cls_logits, labels)
        loss_distill = criterion(distill_logits, teacher_labels)
        loss = 0.5 * loss_cls + 0.5 * loss_distill
        
        # 反向传播
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch+1}, Batch {i+1}] Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# 验证函数
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)  # 使用cls token输出
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total

# 训练主循环
num_epochs = 300
for epoch in range(num_epochs):
    train_epoch(student, teacher, train_loader, optimizer, criterion, epoch)
    val_acc = evaluate(student, val_loader)
    print(f'Epoch {epoch+1}, Validation Accuracy: {val_acc:.2f}%')
    scheduler.step()

# 保存模型
torch.save(student.state_dict(), 'deit_b.pth')

3. 推理代码

推理代码用于加载训练好的模型并对单张图像进行预测。

import torch
from PIL import Image
from torchvision import transforms

# 加载模型
model = DeiT(img_size=224, patch_size=16, num_classes=1000, embed_dim=768, depth=12, num_heads=12)
model.load_state_dict(torch.load('deit_b.pth'))
model = model.to(device)
model.eval()

# 图像预处理
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # [1, 3, 224, 224]
    return image.to(device)

# 推理函数
def predict(image_path):
    image = preprocess_image(image_path)
    with torch.no_grad():
        cls_logits, distill_logits = model(image, return_both=True)
        # 融合预测(late fusion)
        probs = F.softmax(cls_logits, dim=1) + F.softmax(distill_logits, dim=1)
        _, predicted = probs.max(1)
    return predicted.item()

# 示例推理
image_path = 'example.jpg'
pred_class = predict(image_path)
print(f'Predicted class: {pred_class}')

注意事项

  1. 完整实现:上述代码是简化版,未包含所有训练细节(如Rand-Augment、Repeated Augmentation等)。建议参考官方代码(deit/main.py)获取完整训练流程。
  2. 教师模型:这里使用预训练的RegNetY-16GF作为教师,你可以替换为其他模型(如EfficientNet)。
  3. 硬件需求:训练DeiT-B需要至少8GB显存的GPU,建议使用多GPU加速。
  4. 微调:若需在更高分辨率(如384×384)微调,需调整位置编码并重新定义数据加载器。

获取预训练模型

如果你不想从头训练,可以直接从timm库加载预训练的DeiT模型:

import timm
model = timm.create_model('deit_base_patch16_224', pretrained=True)

希望这些代码能帮助你快速上手DeiT的实现!如需更深入的定制或优化,请参考官方文档和论文中的超参数设置。

后记

2025年3月22日16点15分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值