import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
class model(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(model, self).__init__()
        # note: nn.LSTM's third positional argument is num_layers, so output_dim must not be passed there
        self.layer1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        # student model: (2, 8) -> (8, 4)
        # teacher model: (2, 16) -> (16, 4)

    def forward(self, inputs):
        layer1_output, layer1_hidden = self.layer1(inputs)
        layer2_output = self.layer2(layer1_output)
        # take the output vector of the last token of each sequence in the batch,
        # i.e. the sentence-level semantic vector
        layer2_output = layer2_output[:, -1, :]
        return layer2_output
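A quick shape check (my addition, not part of the original write-up): with batch_first=True the LSTM output has shape (batch, seq_len, hidden_dim), so after slicing the last time step the linear head produces (batch, output_dim) logits.
_shape_check = model(input_dim=2, hidden_dim=8, output_dim=4)
print(_shape_check(torch.randn(4, 6, 2)).shape)  # torch.Size([4, 4])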
# build the small (student) model
model_student = model(input_dim=2, hidden_dim=8, output_dim=4)
# build the large (teacher) model (an LSTM is still used here as a stand-in;
# a trained complex model such as BERT could be used instead)
model_teacher = model(input_dim=2, hidden_dim=16, output_dim=4)
# input data; randomly generated data is used here as a placeholder
inputs = torch.randn(4, 6, 2)
true_label = torch.tensor([0, 1, 0, 0])
# build the dataset
dataset = TensorDataset(inputs, true_label)
# build the dataloader
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset=dataset, sampler=sampler, batch_size=2)
loss_fun = CrossEntropyLoss()
criterion = nn.KLDivLoss(reduction="batchmean")  # KL divergence ("batchmean" gives the true KL value)
# optimizer: only the student model's parameters are passed in, so only the student is
# updated, which conveniently keeps the teacher model's parameters fixed
optimizer = torch.optim.SGD(model_student.parameters(), lr=0.1, momentum=0.9)
for step, batch in enumerate(dataloader):
    inputs = batch[0]
    labels = batch[1]
    # run the input through both the student model and the teacher model
    output_student = model_student(inputs)
    with torch.no_grad():  # the teacher is not trained, so no gradients are needed for it
        output_teacher = model_teacher(inputs)
r"""
#建立小模型
model_student = model(input_dim = 2,hidden_dim = 8,output_dim = 4)
#建立大模型(此处仍然使用LSTM代替,可以使用训练好的BERT等复杂模型)
model_teacher = model(input_dim = 2,hidden_dim = 16,output_dim = 4)
"""
    # hard loss: cross entropy between the student's predictions and the true labels
    loss_hard = loss_fun(output_student, labels)
    # soft loss: KL divergence between the student's and the teacher's predictions;
    # KLDivLoss expects log-probabilities as input and probabilities as target
    loss_soft = criterion(F.log_softmax(output_student, dim=-1),
                          F.softmax(output_teacher, dim=-1))
    loss = 0.9 * loss_soft + 0.1 * loss_hard
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
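A side note (my addition, not in the original code): besides keeping the teacher out of the optimizer, the teacher can also be frozen explicitly, which makes the intent obvious and avoids tracking gradients for it at all.
# explicitly freeze the teacher: switch to eval mode and disable gradient tracking
model_teacher.eval()
for param in model_teacher.parameters():
    param.requires_grad_(False)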
The student model has a relatively small structure:
model_student = model(input_dim=2, hidden_dim=8, output_dim=4)
The teacher model has a relatively large structure:
model_teacher = model(input_dim=2, hidden_dim=16, output_dim=4)
The final loss combines the KL divergence between model_student's and model_teacher's outputs (the soft loss) with the cross entropy against the true labels (the hard loss), weighted as below; a small helper version of this combination follows these notes.
loss = 0.9 * loss_soft + 0.1 * loss_hard
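A minimal sketch (my addition; the helper name and signature are my own) that wraps the combination above into a function. The temperature T is a common distillation refinement that the loop above does not use; with T = 1.0 and alpha = 0.9 this reproduces the 0.9/0.1 weighting exactly.
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.9, T=1.0):
    # soft loss: KL divergence between the temperature-softened student and teacher distributions;
    # the T*T factor keeps gradient magnitudes comparable across temperatures
    soft = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
    ) * (T * T)
    # hard loss: cross entropy between the student's logits and the true labels
    hard = CrossEntropyLoss()(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard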
The distillation here feels a bit like the R-Drop trick: run two models and use the KL divergence between their outputs as part of the loss. The R-Drop reference snippet below shows the same idea.
import torch.nn.functional as F
# define your task model, which outputs the classifier logits
model = TaskModel()
def compute_kl_loss(p, q, pad_mask=None):
    p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

    # pad_mask is for seq-level tasks
    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)

    # You can choose "sum" or "mean" depending on your task
    p_loss = p_loss.sum()
    q_loss = q_loss.sum()

    loss = (p_loss + q_loss) / 2
    return loss
# keep dropout and forward twice
logits = model(x)
logits2 = model(x)
# cross entropy loss for classifier
ce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))
kl_loss = compute_kl_loss(logits, logits2)
# carefully choose hyper-parameters
loss = ce_loss + α * kl_loss
R-Drop also runs a model twice. The difference is that the teacher–student distillation above runs two models of different sizes, while R-Drop runs the same model twice (with dropout active). Both use the KL divergence between the two sets of outputs as part of the loss, i.e. both perform two forward passes and penalize the discrepancy between them, as the small sketch below illustrates.
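A minimal sketch (my addition; the toy network and names are purely illustrative) of why the two passes in R-Drop differ at all: with the model in train mode, dropout is stochastic, so the same input yields two different sets of logits, and the symmetric KL term penalizes that gap.
import torch
import torch.nn as nn
import torch.nn.functional as F

# a toy classifier with dropout; the architecture is illustrative only
toy = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 4))
toy.train()  # keep dropout active, as R-Drop requires

x = torch.randn(4, 2)
logits1 = toy(x)  # first stochastic forward pass
logits2 = toy(x)  # second forward pass with a different dropout mask

# symmetric KL between the two predictive distributions, as in compute_kl_loss above
kl = 0.5 * (
    F.kl_div(F.log_softmax(logits1, dim=-1), F.softmax(logits2, dim=-1), reduction='batchmean')
    + F.kl_div(F.log_softmax(logits2, dim=-1), F.softmax(logits1, dim=-1), reduction='batchmean')
)
print(kl)  # generally positive because the two dropout masks differ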