Knowledge Distillation损失函数的实现,KL损失+原CE损失
在知识蒸馏(Knowledge Distillation)中,Kullback-Leibler (KL) 散度损失函数通常用于衡量学生模型输出的概率分布与教师模型输出的概率分布之间的差异。通过最小化这种差异,学生模型可以从教师模型中学习更丰富的信息。下面是一个典型的使用 KL 散度损失函数进行知识蒸馏的示例:
1. 导入必要的库
首先,确保你已经安装了 PyTorch:
pip install torch torchvision
然后,导入必要的库:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
2. 定义教师模型和学生模型
假设你已经有一个预训练的教师模型和一个待训练的学生模型:
class TeacherModel(nn.Module):
def __init__(self):
super(TeacherModel, self).__init__()
# 定义教师模型的网络结构
self.fc = nn.Linear(784, 10)
def forward(self, x):
x = x.view(-1, 784)
return self.fc(x)
class StudentModel(nn.Module):
def __init__(self):
super(StudentModel, self).__init__()
# 定义学生模型的网络结构
self.fc = nn.Linear(784, 10)
def forward(self, x):
x = x.view(-1, 784)
return self.fc(x)
3. 定义知识蒸馏的损失函数
在知识蒸馏过程中,我们通常使用一个组合的损失函数,包括传统的交叉熵损失和 KL 散度损失。
不同的超参数可以T和alpha可以得到不同的结果,具体超参数的选取要根据实际情况
def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
"""
计算知识蒸馏的损失函数。
:param student_logits: 学生模型的输出
:param teacher_logits: 教师模型的输出
:param labels: 真实标签
:param T: 温度参数
:param alpha: 权重参数
:return: 组合损失
"""
# 交叉熵损失
hard_loss = F.cross_entropy(student_logits, labels)
# KL 散度损失
soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean') * (T * T)
# 组合损失
return alpha * soft_loss + (1.0 - alpha) * hard_loss
4. 训练学生模型
定义模型、优化器和训练过程:
# 实例化模型
teacher_model = TeacherModel()
student_model = StudentModel()
# 假设教师模型已经预训练
# teacher_model.load_state_dict(torch.load('teacher_model.pth'))
# 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
# 训练学生模型
def train_student(student_model, teacher_model, train_loader, T, alpha, epochs):
teacher_model.eval() # 设置教师模型为评估模式
student_model.train() # 设置学生模型为训练模式
for epoch in range(epochs):
for data, target in train_loader:
optimizer.zero_grad()
# 前向传播
student_logits = student_model(data)
with torch.no_grad():
teacher_logits = teacher_model(data)
# 计算损失
loss = distillation_loss(student_logits, teacher_logits, target, T, alpha)
# 反向传播和优化
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
# 示例训练数据加载器
train_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,))),
batch_size=32, shuffle=True)
# 训练模型
train_student(student_model, teacher_model, train_loader, T=2.0, alpha=0.5, epochs=5)
在这个示例中:
T
是蒸馏的温度参数,通常大于1,用于平滑教师模型的输出概率分布。alpha
是权重参数,用于平衡 KL 散度损失和交叉熵损失。
通过这种方式,学生模型可以从教师模型中学习知识,并提高其性能。