2015年由Hinton提出这个概念
知识蒸馏与模型压缩的关系
知识蒸馏是模型压缩的一种方法
模型压缩还有其他方法,如低秩近似(low-rank Approximation),网络剪枝(network pruning),网络量化(network quantization)等
Hard-target 和 Soft-target
soft target相对于hard target,携带更多更多有用的信息
其中 Pi 是每个类别输出的概率,Zi 是每个类别输出的 logits,T 就是温度。当温度 T=1 时,这就是标准的 Softmax 公式。 T越高,softmax 的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。
知识蒸馏起初是被用于分类任务的
目前大多数知识蒸馏都是用于分类任务的
学生学到的知识叫做"dark knowledge"
student align knowledge with teacher
Pytorch实现
简单实现,主要为了理解其原理
import torch import torch.nn as nn import numpy as np import ipdb from torch.nn import CrossEntropyLoss from torch.utils.data import TensorDataset,DataLoader,SequentialSampler class model(nn.Module): def __init__(self,input_dim,hidden_dim,output_dim): super(model,self).__init__() self.layer1 = nn.LSTM(input_dim,hidden_dim,output_dim,batch_first = True) self.layer2 = nn.Linear(hidden_dim,output_dim) def forward(self,inputs): layer1_output, layer1_hidden = self.layer1(inputs) layer2_output = self.layer2(layer1_output) layer2_output = layer2_output[:,-1,:]#取出一个batch中每个句子最后一个单词的输出向量即该句子的语义向量!!!!!!!!! return layer2_output #建立小模型 model_student = model(input_dim = 2,hidden_dim = 8,output_dim = 4) #建立大模型(此处仍然使用LSTM代替,可以使用训练好的BERT等复杂模型) model_teacher = model(input_dim = 2,hidden_dim = 16,output_dim = 4) #设置输入数据,此处只使用随机生成的数据代替 inputs = torch.randn(4,6,2) #[bs,l,dim] true_label = torch.tensor([0,1,0,0]) #生成dataset dataset = TensorDataset(inputs,true_label) #生成dataloader sampler = SequentialSampler(inputs) dataloader = DataLoader(dataset = dataset,sampler = sampler,batch_size = 2) loss_fun = CrossEntropyLoss() criterion = nn.KLDivLoss()#KL散度 optimizer = torch.optim.SGD(model_student.parameters(),lr = 0.1,momentum = 0.9)#优化器,优化器中只传入了学生模型的参数,因此此处只对学生模型进行参数更新,正好实现了教师模型参数不更新的目的 for step,batch in enumerate(dataloader): inputs = batch[0] #[bs,l,dim] labels = batch[1] #[bs] # ipdb.set_trace() #分别使用学生模型和教师模型对输入数据进行计算 output_student = model_student(inputs) #[bs,dim'] output_teacher = model_teacher(inputs) #[bs,dim'] #计算学生模型预测结果和教师模型预测结果之间的KL散度 loss_soft = criterion(output_student,output_teacher) #计算学生模型和真实标签之间的交叉熵损失函数值 loss_hard = loss_fun(output_student,labels) loss = 0.9*loss_soft + 0.1*loss_hard print(loss) optimizer.zero_grad() loss.backward() optimizer.step()