第03篇:Feature-based 知识蒸馏——中间层特征传递的艺术

本篇带你理解特征蒸馏的基本思路、经典方法(如 FitNet 和 Attention Transfer),并实现一个 PyTorch 示例。

一、什么是 Feature-based 蒸馏?

相较于上一篇介绍的 Soft Target(输出蒸馏),Feature-based 蒸馏强调的是:

让学生模型模仿 教师模型中间层的特征表示

这就好比学生不仅要学会“答对题”(预测分类),还要“思考方式一样”(中间表示一致)。

二、为什么使用 Feature-based 蒸馏?

✅ 优点:
特性说明
更细粒度中间层特征保留更多结构与空间信息
泛化更强早期层能引导学生捕捉关键模式
可用于多任务特征蒸馏适用于目标检测、分割等任务
局限:
  • 教师和学生结构差异大时对齐较难;

  • 特征对齐会增加内存和训练复杂度

三、经典方法简析

① FitNet(Romero et al., 2015)

核心思想:让学生中间层 mimick 教师中间层。

  • 教师层输出:

  • 学生层输出:

使用一个 regressor(卷积层)把 F_s投影到 C_t维度,再与 F_t计算 L2 loss。

✅ 强调“低层语义一致性”。

 

ttention Transfer(Zagoruyko & Komodakis, 2017)

核心思想:不是直接对齐特征,而是对齐“注意力图”。

  • 注意力图计算如下:

即:对通道求平方再求和,得到一个 H×WH \times WH×W 的注意力热图。

  • 蒸馏损失:

 ✅ 更关注“学生关注了哪里”。

四、PyTorch 实践:实现 Attention Transfer 蒸馏

我们以 CIFAR-10 分类为例,教师用 ResNet18,学生为简化网络。

✅ Step 1:定义注意力图函数
def attention_map(feature):
    # feature: [B, C, H, W]
    return torch.norm(feature, p=2, dim=1)  # [B, H, W]

 ✅ Step 2:定义 AT 损失函数

def at_loss(student_feature, teacher_feature):
    student_att = attention_map(student_feature)
    teacher_att = attention_map(teacher_feature)

    student_att = nn.functional.normalize(student_att.view(student_att.size(0), -1), p=2, dim=1)
    teacher_att = nn.functional.normalize(teacher_att.view(teacher_att.size(0), -1), p=2, dim=1)

    return ((student_att - teacher_att)**2).mean()

✅ Step 3:提取中间特征(hook 方式)

student_feats, teacher_feats = [], []

def get_activation(name, feat_list):
    def hook(model, input, output):
        feat_list.append(output)
    return hook

# Hook 中间层
teacher_model.layer2.register_forward_hook(get_activation('t_layer2', teacher_feats))
student_model.features[2].register_forward_hook(get_activation('s_layer2', student_feats))

✅ Step 4:训练流程

for inputs, labels in dataloader:
    student_feats.clear()
    teacher_feats.clear()

    with torch.no_grad():
        _ = teacher_model(inputs)

    outputs = student_model(inputs)
    loss_ce = nn.CrossEntropyLoss()(outputs, labels)
    loss_at = at_loss(student_feats[0], teacher_feats[0])

    loss = alpha * loss_ce + (1 - alpha) * loss_at

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

五、特征蒸馏技巧汇总

建议
蒸馏层选择可选 shallow、middle 或 deep layers,建议做 ablation
匹配方式可用 conv、1×1 conv、MLP 做维度对齐
多层蒸馏多层加权融合损失更稳定
AT vs FitNetAT 对结构兼容性要求低;FitNet 适合有特征对齐机制的结构

六、进阶:融合 Soft + Feature 蒸馏

蒸馏损失可以是组合形式:

total_loss = α * CE + β * KL + γ * AT

 多源信息引导,能进一步提升学生模型的泛化能力。

七、小结

内容说明
⭐ 蒸馏类型Feature-based 提供中间层监督信号
🔧 方法代表FitNet、Attention Transfer
🚀 应用场景分类、检测、分割等多任务模型压缩

下一篇预告

📘 第04篇:Relation-based 蒸馏——建模教师与学生间的“结构关系”

  • 如何建模特征间的空间/语义关系;

  • 介绍 RKD、PKD 等方法;

  • PyTorch 实现实例。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值