【笔记】KL散度教师模型(软标签)与学生模型:Knowledge Distillation中,教师模型和学生模型通常具有相同或相似的架构,但教师模型已经被训练好,而学生模型则从头开始训练,或是进行微调

注1:

KL散度(Kullback-Leibler Divergence),全称是Kullback-Leibler Divergence,是一种用于衡量两个概率分布之间差异的非对称度量。KL散度在信息论和机器学习中广泛使用,包括在知识蒸馏过程中用来衡量学生模型输出与教师模型输出之间的差异。

注2:


KL散度在知识蒸馏中的应用

在知识蒸馏过程中,教师模型生成的软标签提供了关于每个类的概率分布,而不仅仅是一个硬标签(one-hot vector)。学生模型通过最小化其输出与教师模型输出之间的KL散度,来学习教师模型的知识。

注3:

KL 散度的Code eg:

import torch
import torch.nn as nn

# 示例输出:教师模型和学生模型的概率分布
teacher_outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2]], requires_grad=False)  # 教师模型的软标签
student_outputs = torch.tensor([[0.2, 0.8], [0.6, 0.4]], requires_grad=True)  # 学生模型的预测

# 使用KL散度计算损失
criterion = nn.KLDivLoss(reduction='batchmean')  # 使用批次平均的KL散度
loss = criterion(nn.functional.log_softmax(student_outputs, dim=1),
                 nn.functional.softmax(teacher_outputs, dim=1))

print(f'KL散度损失: {loss.item()}')


教师模型和学生模型的输出:

教师模型的输出 teacher_outputs 是生成的软标签,表示每个类的概率分布。
学生模型的输出 student_outputs 是学生模型对输入数据的预测。


计算KL散度损失:

使用 nn.KLDivLoss 计算KL散度损失,reduction='batchmean' 表示对整个批次的损失取平均。
需要注意的是,在计算KL散度时,学生模型的输出需要先通过 log_softmax 函数,教师模型的输出需要通过 softmax 函数。log_softmax 是因为 nn.KLDivLoss 期望输入是对数概率分布。

正文:

知识蒸馏的基本流程
训练教师模型:先训练一个性能较好的大模型,称为教师模型。


使用教师模型生成软标签:用训练好的教师模型对数据进行预测,生成包含各类别概率分布的软标签。


训练学生模型:使用教师模型生成的软标签以及原始标签,同时指导学生模型的训练。训练过程中,学生模型的目标是尽量模仿教师模型的行为。


为什么使用知识蒸馏
知识蒸馏的主要目的是将教师模型中学到的知识传递给一个较小的学生模型,使学生模型在计算资源有限的情况下也能有较好的性能。这在实际应用中非常有用,因为较小的模型在推理时更加高效。

分类损失(loss_cls):使用真实标签和学生模型的分类标记输出计算交叉熵损失。


蒸馏损失(loss_dist):使用教师模型的软标签和学生模型的蒸馏标记输出计算KL散度损失。

import torch
import torch.nn as nn
from collections import OrderedDict
from functools import partial

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim * mlp_ratio), drop_ratio)
            for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if distilled else None

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_vit_weights)

    def _init_vit_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.dist_token is not None:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training:
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x

# 教师模型生成软标签的示例(伪代码)
def teacher_model_forward(teacher_model, x):
    with torch.no_grad():
        return teacher_model(x)

# 示例使用
teacher_model = VisionTransformer()  # 预训练好的教师模型
student_model = VisionTransformer(distilled=True)  # 学生模型
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 假设有一些输入数据和真实标签
input_tensor = torch.randn(8, 3, 224, 224)  # 8个样本的批次
true_labels = torch.randint(0, 1000, (8,))

# 教师模型生成软标签
teacher_outputs = teacher_model_forward(teacher_model, input_tensor)

# 学生模型前向传播
student_outputs, student_dist_outputs = student_model(input_tensor)

# 计算蒸馏损失(结合真实标签和教师模型输出的软标签)
loss_cls = criterion(student_outputs, true_labels)
loss_dist = nn.KLDivLoss()(nn.functional.log_softmax(student_dist_outputs, dim=1),
                           nn.functional.softmax(teacher_outputs, dim=1))
loss = loss_cls + loss_dist

# 反向传播和优化
loss.backward()
optimizer.step()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序猿的探索之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值