注1:
KL散度(Kullback-Leibler Divergence),全称是Kullback-Leibler Divergence,是一种用于衡量两个概率分布之间差异的非对称度量。KL散度在信息论和机器学习中广泛使用,包括在知识蒸馏过程中用来衡量学生模型输出与教师模型输出之间的差异。
注2:
KL散度在知识蒸馏中的应用
在知识蒸馏过程中,教师模型生成的软标签提供了关于每个类的概率分布,而不仅仅是一个硬标签(one-hot vector)。学生模型通过最小化其输出与教师模型输出之间的KL散度,来学习教师模型的知识。
注3:
KL 散度的Code eg:
import torch
import torch.nn as nn
# 示例输出:教师模型和学生模型的概率分布
teacher_outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2]], requires_grad=False) # 教师模型的软标签
student_outputs = torch.tensor([[0.2, 0.8], [0.6, 0.4]], requires_grad=True) # 学生模型的预测
# 使用KL散度计算损失
criterion = nn.KLDivLoss(reduction='batchmean') # 使用批次平均的KL散度
loss = criterion(nn.functional.log_softmax(student_outputs, dim=1),
nn.functional.softmax(teacher_outputs, dim=1))
print(f'KL散度损失: {loss.item()}')
教师模型和学生模型的输出:
教师模型的输出 teacher_outputs 是生成的软标签,表示每个类的概率分布。
学生模型的输出 student_outputs 是学生模型对输入数据的预测。
计算KL散度损失:
使用 nn.KLDivLoss 计算KL散度损失,reduction='batchmean' 表示对整个批次的损失取平均。
需要注意的是,在计算KL散度时,学生模型的输出需要先通过 log_softmax 函数,教师模型的输出需要通过 softmax 函数。log_softmax 是因为 nn.KLDivLoss 期望输入是对数概率分布。
正文:
知识蒸馏的基本流程
训练教师模型:先训练一个性能较好的大模型,称为教师模型。
使用教师模型生成软标签:用训练好的教师模型对数据进行预测,生成包含各类别概率分布的软标签。
训练学生模型:使用教师模型生成的软标签以及原始标签,同时指导学生模型的训练。训练过程中,学生模型的目标是尽量模仿教师模型的行为。
为什么使用知识蒸馏
知识蒸馏的主要目的是将教师模型中学到的知识传递给一个较小的学生模型,使学生模型在计算资源有限的情况下也能有较好的性能。这在实际应用中非常有用,因为较小的模型在推理时更加高效。
分类损失(loss_cls):使用真实标签和学生模型的分类标记输出计算交叉熵损失。
蒸馏损失(loss_dist):使用教师模型的软标签和学生模型的蒸馏标记输出计算KL散度损失。
import torch
import torch.nn as nn
from collections import OrderedDict
from functools import partial
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):
super(PatchEmbed, self).__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
self.blocks = nn.Sequential(*[
nn.TransformerEncoderLayer(embed_dim, num_heads, int(embed_dim * mlp_ratio), drop_ratio)
for _ in range(depth)
])
self.norm = norm_layer(embed_dim)
if representation_size and not distilled:
self.has_logits = True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
("fc", nn.Linear(embed_dim, representation_size)),
("act", nn.Tanh())
]))
else:
self.has_logits = False
self.pre_logits = nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if distilled else None
nn.init.trunc_normal_(self.pos_embed, std=0.02)
if self.dist_token is not None:
nn.init.trunc_normal_(self.dist_token, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_vit_weights)
def _init_vit_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x).flatten(2).transpose(1, 2)
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.dist_token is not None:
dist_tokens = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
else:
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1])
if self.training:
return x, x_dist
else:
return (x + x_dist) / 2
else:
x = self.head(x)
return x
# 教师模型生成软标签的示例(伪代码)
def teacher_model_forward(teacher_model, x):
with torch.no_grad():
return teacher_model(x)
# 示例使用
teacher_model = VisionTransformer() # 预训练好的教师模型
student_model = VisionTransformer(distilled=True) # 学生模型
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 假设有一些输入数据和真实标签
input_tensor = torch.randn(8, 3, 224, 224) # 8个样本的批次
true_labels = torch.randint(0, 1000, (8,))
# 教师模型生成软标签
teacher_outputs = teacher_model_forward(teacher_model, input_tensor)
# 学生模型前向传播
student_outputs, student_dist_outputs = student_model(input_tensor)
# 计算蒸馏损失(结合真实标签和教师模型输出的软标签)
loss_cls = criterion(student_outputs, true_labels)
loss_dist = nn.KLDivLoss()(nn.functional.log_softmax(student_dist_outputs, dim=1),
nn.functional.softmax(teacher_outputs, dim=1))
loss = loss_cls + loss_dist
# 反向传播和优化
loss.backward()
optimizer.step()