基于知识蒸馏Knowledge Distillation模型压缩pytorch实现

在弄懂原理基础上,从本篇博客开始,逐步介绍基于知识蒸馏的增量学习、模型压缩的代码实现。毕竟“纸上得来终觉浅,绝知此事要躬行。”。

先从最经典的Hilton论文开始,先实现基于知识蒸馏的模型压缩。相关原理可以参考博客:https://blog.csdn.net/zhenyu_an/article/details/101646943

既然基本原理是用一个已训练的teacher网络,去教会一个student网络,那首先需要定义这两个网络如下。这里我们采用pytorch语言,以最简单的mnist数据集为例来看看效果。

先定义student网络,由一个卷积层、池化层、全连接层构成,很简单。

class anNet(nn.Module):
    def __init__(self):
        super(anNet,self).__init__()
        self.conv1 = nn.Conv2d(1,6,3)
        self.pool1 = nn.MaxPool2d(2,1)
        self.fc3 = nn.Linear(3750,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(F.relu(x))
        x = x.view(x.size()[0],-1)
        x = self.fc3(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

再定义一个teacher网络,由两个卷积、两个池化、一个全连接层组成。

class anNet_deep(nn.Module):
    def __init__(self):
        super(anNet_deep,self).__init__()
        self.conv1 = nn.Sequential(
                nn.Conv2d(1,64,3,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv2 = nn.Sequential(
                nn.Conv2d(64,64,3,1,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv3 = nn.Sequential(
                nn.Conv2d(64,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.conv4 = nn.Sequential(
                nn.Conv2d(128,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.pooling1 = nn.Sequential(nn.MaxPool2d(2,stride=2))
        self.fc = nn.Sequential(nn.Linear(6272,10))
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pooling1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pooling1(x)
        x = x.view(x.size()[0],-1)
        x = self.fc(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

为了提高teacher网络的性能,在每个卷积层后面加上了BN层。通过print(sum(x.numel() for x in model.parameters()))可以计算出teacher网络和student网络的参数个数分别为:618186和37570,二者相差16倍。

我们首先在mnist数据集上分别对两种模型训练,采用相同的优化方法optimizer = optim.Adam(model.parameters(),lr = 0.001)、相同的损失函数criterion = nn.CrossEntropyLoss()和相同的epoch,teacher网络得到最佳测试正确率大约在0.989至0.991之间,student网络得到的最佳测试正确率大约在0.957至0.959之间。总体而言,在16倍的参数差距面前,student网络干不过teacher网络。把teacher网络训练好的模型保存好。

下面开始知识蒸馏中的关键代码:

知识蒸馏的关键是loss的设计,它包括普通的交叉熵loss1和建立在软目标基础上的loss2。分别如下:

# 损失函数
criterion = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss()
# 经典损失
outputs = model(inputs.float())
loss1 = criterion(outputs, labels)
# 蒸馏损失        
teacher_outputs = teach_model(inputs.float())
T = 2
alpha = 0.5
outputs_S = F.log_softmax(outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)
loss2 = criterion2(outputs_S,outputs_T)*T*T

#综合损失结果
loss = loss1*(1-alpha) + loss2*alpha

仔细看这段代码,loss1的设计没有问题,它衡量student网络输出与标准值labels的差距,用的是交叉熵损失。

loss2是关键的蒸馏损失,它的衡量的是student网络输出与已训练好的teacher网络输出,经过软化的结果之间差距。其中outputs_S 代表student网络输出软化后结果,outputs_T 代表teacher网络输出软化后结果,二者采用的是KL散度损失函数。T和alpha是两个超参数,取法对结果影响很大,T的取法一般可以有2,10,20几种,alpha一般取0.5,0.9,0.95几种。需要留意的是这里采用的两种软化方法,student网络输出后加一个log_softmax(outputs/T),teacher网络输出后加一个softmax(参考了https://github.com/PolarisShi/distillation的写法)。这里的问题在于,pytorch源码实现的KL散度是一个阉割版本,并没有对预测结果做log处理,作者试图在这里给补上。事实上,这种写法也没有完全实现标准KL散度的公式,因为漏了log之前的值,最后写成了四不像,反倒是都用softmax的话至少与pytorch的思路一致。待重开一篇博客详细介绍pytorch中的KL散度与交叉熵。

我们把训练好的teacher网络参数导入,开始蒸馏训练,student网络最后精度可以提升到97.02%,没有预想中效果明显,可能是因为超参数的取值不合适。

完整代码和训练好的模型见个人github:https://github.com/azy1988/ML-CV/tree/master/model_distillation

 

  • 19
    点赞
  • 113
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
知识蒸馏是一种模型压缩技术,通过将一个复杂的模型(教师模型)的知识转移到一个简化的模型(学生模型)中,从而提高学生模型的性能。在PyTorch中,可以使用以下步骤实现知识蒸馏: 1. 定义教师模型和学生模型:首先,需要定义一个教师模型和一个学生模型。教师模型通常是一个复杂的模型,而学生模型是一个简化的模型。 2. 加载和准备数据集:接下来,需要加载和准备用于训练的数据集。这包括数据的预处理、划分为训练集和测试集等步骤。 3. 定义损失函数:在知识蒸馏中,通常使用两个损失函数:一个是用于学生模型的普通损失函数(如交叉熵损失),另一个是用于学生模型和教师模型之间的知识蒸馏损失函数(如平均软标签损失)。 4. 定义优化器:选择一个合适的优化器来更新学生模型的参数。常用的优化器包括随机梯度下降(SGD)和Adam。 5. 训练学生模型:使用加载的数据集和定义的损失函数和优化器,通过迭代训练学生模型。在每个训练步骤中,计算学生模型的损失,并根据损失更新学生模型的参数。 6. 应用知识蒸馏:在计算学生模型的损失时,还需要计算教师模型的输出,并使用知识蒸馏损失函数来衡量学生模型和教师模型之间的相似性。通过最小化知识蒸馏损失,学生模型可以从教师模型中获得更多的知识。 7. 评估学生模型:在训练完成后,使用测试集评估学生模型的性能。可以计算准确率、精确率、召回率等指标来评估学生模型的性能。 以下是一个示例代码,演示了如何在PyTorch实现知识蒸馏: ```python import torch import torch.nn as nn import torch.optim as optim # 定义教师模型和学生模型 teacher_model = TeacherModel() student_model = StudentModel() # 加载和准备数据集 train_dataset = ... test_dataset = ... train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) # 定义损失函数 criterion_student = nn.CrossEntropyLoss() criterion_distillation = nn.KLDivLoss() # 定义优化器 optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9) # 训练学生模型 for epoch in range(num_epochs): student_model.train() for inputs, labels in train_loader: optimizer.zero_grad() outputs_student = student_model(inputs) outputs_teacher = teacher_model(inputs) # 计算学生模型的损失 loss_student = criterion_student(outputs_student, labels) # 计算知识蒸馏损失 loss_distillation = criterion_distillation(torch.log_softmax(outputs_student, dim=1), torch.softmax(outputs_teacher, dim=1)) # 总损失为学生模型损失和知识蒸馏损失之和 loss = loss_student + alpha * loss_distillation loss.backward() optimizer.step() # 评估学生模型 student_model.eval() with torch.no_grad(): correct = 0 total = 0 for inputs, labels in test_loader: outputs = student_model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total print("Accuracy: {:.2f}%".format(accuracy * 100)) ```
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值