知识蒸馏(Knowledge Distillation)的Pytorch实现以及分析

       知识蒸馏(Knowledge Distillation)的概念由Hinton大神于2015年在论文《Distilling the Knowledge in a Neural Network》中提出,论文见:https://arxiv.org/abs/1503.02531。此方法的主要思想为:通过结构复杂、计算量大但是性能优秀的教师神经网络,对结构相对简单、计算量较小的学生神经网络进行指导,以提升学生神经网络的性能。论文中提出了“暗知识”这一概念,即:比如我们在识别一张猫猫的图片的时候,一个性能良好的神经网络经过softmax变换后的输出,在一般该向量中代表猫猫的位置会得到一个非常高的值,比如,0.9,而代表其它分类的值在传统的研究中就不那么受重视了。Hinton大神认为,其它位置得到的值能够为我们提供一些额外的信息,比如,在猫得到0.9的同时,识别为狮子的值可能因为相似的缘故给到了0.09,而识别为汽车的值则可能只有0.0001。在我的理解中,这种目标间的相似性,就是“暗知识”的本质。为了要放大这种“暗知识”所包含的信息,Hinton在传统的softmax函数中加入温度参数T,变为下式所示:

                                                                        

       那么,知识蒸馏的步骤分别为:

一、采用传统方式训练一个教师网络。

二、建立学生网络模型,模型的输出采用传统的softmax函数,拟合目标为one-hot形式的训练集输出,它们之间的距离记为loss1。

三、将训练完成的教师网络的softmax分类器加入温度参数,作为具有相同温度参数softmax分类器的学生网络的拟合目标,他们之间的距离记为loss2。

四、引入参数alpha,将loss1×(1-alpha)+loss2×alpha作为网络训练时使用的loss,训练网络。

       重点就在于将暗知识放大之后,让学生网络的暗知识去拟合教师网络的暗知识,同时由于教师网络会带有一定的bias,表现为教师网络在训练完成后,对训练集识别的正确率会高于测试集,所以加上loss1来减缓这种趋势,实际应用的时候,可以考虑将alpha首先设置的接近1,比如0.95啥的,来快速拟合教师网络,再逐步调低alpha的值,来确保网络的分类正确率,不过这只是理论上可行的,我也没试验就是了……

       那我们就开搞啦,首先是搭建教师网络,我这里选择的是resnet18,并且由于电脑训练速度的原因(渣机无力……)将网络中所有卷积核的数目减少了一半,训练集采用Cifar10,训练时对图像进行了padding之后随机裁剪以及随机水平翻转来加入噪声。优化器采用带动量项的SGD(lr=0.1, momentum=0.9, weight_decay=5e-4),训练200个epoch,其中在第100以及第150个epoch时将学习率除10,详细的代码见文章末尾的github地址好啦。训练完成后,网络对测试集的识别结果如下所示:

Accuracy of the network on the 10000 test images: 93.970000 %
Accuracy of plane : 97.727273 %
Accuracy of   car : 100.000000 %
Accuracy of  bird : 84.210526 %
Accuracy of   cat : 86.046512 %
Accuracy of  deer : 93.877551 %
Accuracy of   dog : 96.875000 %
Accuracy of  frog : 98.113208 %
Accuracy of horse : 93.750000 %
Accuracy of  ship : 95.833333 %
Accuracy of truck : 100.000000 %

       这结果当然并不算特别好,所以作为学生的网络,得选个效果比较差的,这样才能体现出教师的价值对吧(笑)。这里我们就简单的架一个三层卷积神经网络作为学生网络好啦,网络具体结构见github。还是使用cifar10经过相同的图像变换过程后,采用adam(lr=0.001)作为优化器对网络训练100个epoch,在完全相同的条件下训练四次,测试集识别结果分别如下,我们可以看到,这几次的训练结果平均一下大概在84%左右。

第一次训练结果:
Accuracy of the network on the 10000 test images: 84.350000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 72.093023 %
Accuracy of  deer : 83.673469 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 94.339623 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 83.333333 %
Accuracy of truck : 93.103448 %
第二次训练结果:
Accuracy of the network on the 10000 test images: 83.870000 %
Accuracy of plane : 97.727273 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 76.744186 %
Accuracy of  deer : 91.836735 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 96.551724 %
第三次训练结果:
Accuracy of the network on the 10000 test images: 84.760000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 68.421053 %
Accuracy of   cat : 72.093023 %
Accuracy of  deer : 83.673469 %
Accuracy of   dog : 84.375000 %
Accuracy of  frog : 90.566038 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 86.206897 %
第四次训练结果:
Accuracy of the network on the 10000 test images: 84.240000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 81.578947 %
Accuracy of   cat : 74.418605 %
Accuracy of  deer : 77.551020 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 83.018868 %
Accuracy of horse : 93.750000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 82.758621 %

       接下来,因为之前看到网上有人说,教师网络本身在训练的时候,是有采用加噪数据进行训练的,所以它的输出的暗知识在理论上可能会包含有噪声项的信息,我们就先在不对数据集进行变换的情况下进行训练。这里我们选取alpha=0.95,T选取2和10分别训练两次,结果如下。我们可以看到,其训练的结果比之前的方法是要差的,这可能是因为学生网络还是直接过拟合了教师网络的输出,所以导致测试集正确率较低。

T=2第一次训练结果:
Accuracy of the network on the 10000 test images: 79.110000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 68.421053 %
Accuracy of   cat : 67.441860 %
Accuracy of  deer : 69.387755 %
Accuracy of   dog : 68.750000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 86.206897 %
T=2第二次训练结果:
Accuracy of the network on the 10000 test images: 76.720000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 62.790698 %
Accuracy of  deer : 73.469388 %
Accuracy of   dog : 59.375000 %
Accuracy of  frog : 77.358491 %
Accuracy of horse : 81.250000 %
Accuracy of  ship : 83.333333 %
Accuracy of truck : 79.310345 %
T=10第一次训练结果:
Accuracy of the network on the 10000 test images: 78.600000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 62.790698 %
Accuracy of  deer : 75.510204 %
Accuracy of   dog : 62.500000 %
Accuracy of  frog : 83.018868 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 86.206897 %
T=10第二次训练结果:
Accuracy of the network on the 10000 test images: 76.550000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 73.684211 %
Accuracy of   cat : 67.441860 %
Accuracy of  deer : 75.510204 %
Accuracy of   dog : 62.500000 %
Accuracy of  frog : 86.792453 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 75.862069 %

       最后是对图片进行了相应的变换加入噪声后,对学生网络进行训练,结果如下:

T=2第一次训练结果:
Accuracy of the network on the 10000 test images: 85.190000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 96.875000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 83.720930 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 84.375000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 75.000000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 93.103448 %
T=2第二次训练结果:
Accuracy of the network on the 10000 test images: 84.490000 %
Accuracy of plane : 93.181818 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 73.684211 %
Accuracy of   cat : 76.744186 %
Accuracy of  deer : 85.714286 %
Accuracy of   dog : 78.125000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 89.655172 %
T=10第一次训练结果:
Accuracy of the network on the 10000 test images: 85.310000 %
Accuracy of plane : 100.000000 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 83.720930 %
Accuracy of  deer : 87.755102 %
Accuracy of   dog : 75.000000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 93.750000 %
Accuracy of truck : 89.655172 %
T=10第二次训练结果:
Accuracy of the network on the 10000 test images: 85.370000 %
Accuracy of plane : 95.454545 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 76.315789 %
Accuracy of   cat : 74.418605 %
Accuracy of  deer : 85.714286 %
Accuracy of   dog : 78.125000 %
Accuracy of  frog : 88.679245 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 87.500000 %
Accuracy of truck : 89.655172 %

       虽然测试集的正确率具有一定程度的不确定性,我们还是可以看出,测试集正确率相比原始的训练方法有所提升。这也可以大致说明这种方法的有效性。当然,这种训练方式目前也产生了很多的变体,比如再生网络等等、

       最后是相关程序与训练完成的网络参数文件的github地址:https://github.com/PolarisShi/distillation

  • 14
    点赞
  • 132
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
知识蒸馏是一种模型压缩技术,通过将一个复杂的模型(教师模型)的知识转移到一个简化的模型(学生模型)中,从而提高学生模型的性能。在PyTorch中,可以使用以下步骤实现知识蒸馏: 1. 定义教师模型和学生模型:首先,需要定义一个教师模型和一个学生模型。教师模型通常是一个复杂的模型,而学生模型是一个简化的模型。 2. 加载和准备数据集:接下来,需要加载和准备用于训练的数据集。这包括数据的预处理、划分为训练集和测试集等步骤。 3. 定义损失函数:在知识蒸馏中,通常使用两个损失函数:一个是用于学生模型的普通损失函数(如交叉熵损失),另一个是用于学生模型和教师模型之间的知识蒸馏损失函数(如平均软标签损失)。 4. 定义优化器:选择一个合适的优化器来更新学生模型的参数。常用的优化器包括随机梯度下降(SGD)和Adam。 5. 训练学生模型:使用加载的数据集和定义的损失函数和优化器,通过迭代训练学生模型。在每个训练步骤中,计算学生模型的损失,并根据损失更新学生模型的参数。 6. 应用知识蒸馏:在计算学生模型的损失时,还需要计算教师模型的输出,并使用知识蒸馏损失函数来衡量学生模型和教师模型之间的相似性。通过最小化知识蒸馏损失,学生模型可以从教师模型中获得更多的知识。 7. 评估学生模型:在训练完成后,使用测试集评估学生模型的性能。可以计算准确率、精确率、召回率等指标来评估学生模型的性能。 以下是一个示例代码,演示了如何在PyTorch实现知识蒸馏: ```python import torch import torch.nn as nn import torch.optim as optim # 定义教师模型和学生模型 teacher_model = TeacherModel() student_model = StudentModel() # 加载和准备数据集 train_dataset = ... test_dataset = ... train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) # 定义损失函数 criterion_student = nn.CrossEntropyLoss() criterion_distillation = nn.KLDivLoss() # 定义优化器 optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9) # 训练学生模型 for epoch in range(num_epochs): student_model.train() for inputs, labels in train_loader: optimizer.zero_grad() outputs_student = student_model(inputs) outputs_teacher = teacher_model(inputs) # 计算学生模型的损失 loss_student = criterion_student(outputs_student, labels) # 计算知识蒸馏损失 loss_distillation = criterion_distillation(torch.log_softmax(outputs_student, dim=1), torch.softmax(outputs_teacher, dim=1)) # 总损失为学生模型损失和知识蒸馏损失之和 loss = loss_student + alpha * loss_distillation loss.backward() optimizer.step() # 评估学生模型 student_model.eval() with torch.no_grad(): correct = 0 total = 0 for inputs, labels in test_loader: outputs = student_model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total print("Accuracy: {:.2f}%".format(accuracy * 100)) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值