The seminal work on knowledge distillation: Distilling the Knowledge in a Neural Network
What is knowledge?
we tend to identify the knowledge in a trained model with the learned parameter values.
That is, the knowledge in a trained model tends to be identified with its learned parameter values.
What is distillation?
which we call "distillation" to transfer the knowledge from the cumbersome model to a small model that is more suitable for deployment.
That is, distillation transfers knowledge from a cumbersome large model (the teacher) to a small model (the student) that is better suited for deployment.
Composition of the knowledge distillation loss
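The total loss is a weighted sum of two terms: a hard loss, the ordinary cross-entropy between the student's predictions and the ground-truth labels, and a soft loss, the KL divergence between the teacher's and the student's output distributions softened by a temperature T. Following the original paper, the soft term is scaled by T^2 so that its gradients stay on the same scale as the hard term's. With student logits z_s, teacher logits z_t, labels y, temperature T, and hard-loss weight \alpha, the objective used in the demo below is

L = \alpha \cdot \mathrm{CE}\big(\mathrm{softmax}(z_s),\, y\big) + (1 - \alpha) \cdot T^{2} \cdot \mathrm{KL}\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big)

The demo uses T = 7 and \alpha = 0.4.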
Demo implementation
Using the MNIST dataset, we implement distillation from scratch: first train a large teacher model, then train a much smaller student model with the combined hard/soft loss above.
Import packages
# Import the required modules
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
Set up the CUDA environment
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Let cuDNN benchmark and pick the fastest kernels
torch.backends.cudnn.benchmark = True
# Set the random seed for reproducibility
torch.manual_seed(2022)
Load the dataset
# Load the MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='mnist/dataset/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='mnist/dataset/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)
# Create the dataloaders
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)
Train the teacher model from scratch
# Define the teacher model
class TeacherModel(nn.Module):
    def __init__(self, hidden_dim=1024, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

model = TeacherModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
summary(model)
From the summary output, the teacher model has roughly 1.86 million parameters in total.
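If torchinfo is not installed, the same count can be checked directly in PyTorch; a minimal sketch (for hidden_dim=1024 the exact total works out to 1,863,690):

# Count the teacher's trainable parameters without torchinfo
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Teacher parameters: {:,}'.format(num_params))  # roughly 1.86M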
# Train the teacher model
epochs = 10
for epoch in range(epochs):
    model.train()
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        # Forward pass
        preds = model(data)
        loss = criterion(preds, targets)
        # Backward pass and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the test set after each epoch
    model.eval()
    num_correct, num_samples = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    acc = (num_correct / num_samples).item()

    print('Epoch: {}\t Accuracy: {:.4f}'.format(epoch + 1, acc))
Train the student model with knowledge distillation
# Define the student model
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=128, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

s_model = StudentModel(hidden_dim=10)
s_model = s_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(s_model.parameters(), lr=1e-3)
summary(s_model)
The student model has only 8,070 parameters in total, a drastic compression (roughly 230x fewer parameters than the teacher).
# Use the pretrained teacher model (trained above as `model`)
teacher_model = model
teacher_model.eval()
# Create a fresh student model
s_model = StudentModel(hidden_dim=10)
s_model = s_model.to(device)
s_model.train()
# Distillation temperature
temp = 7
# Hard loss: cross-entropy against the ground-truth labels
hard_loss = nn.CrossEntropyLoss()
# Weight of the hard loss
alpha = 0.4
# Soft loss: KL divergence between softened teacher and student distributions
soft_loss = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(s_model.parameters(), lr=1e-3)
# Distillation training loop
epochs = 5
for epoch in range(epochs):
    s_model.train()
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_preds = teacher_model(data)
        # Student predictions
        student_preds = s_model(data)
        # Hard loss against the ground-truth labels
        student_loss = hard_loss(student_preds, targets)
        # Soft loss: KLDivLoss expects log-probabilities as its first argument,
        # so the student's softened logits go through log_softmax
        distillation_loss = soft_loss(F.log_softmax(student_preds / temp, dim=1),
                                      F.softmax(teacher_preds / temp, dim=1))
        # Weighted total loss; the temp**2 factor keeps the soft-loss gradients
        # on the same scale as the hard-loss gradients
        loss = alpha * student_loss + (1 - alpha) * distillation_loss * (temp ** 2)
        # Backward pass and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate the student on the test set after each epoch
    s_model.eval()
    num_correct, num_samples = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            preds = s_model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    acc = (num_correct / num_samples).item()
    s_model.train()

    print('Epoch: {}\t Accuracy: {:.4f}'.format(epoch + 1, acc))
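As a quick check after training, the distilled student can be used on its own for inference. A minimal sketch, reusing s_model, test_dataset, and device from the code above:

# Predict a single test image with the distilled student
s_model.eval()
img, label = test_dataset[0]
with torch.no_grad():
    logits = s_model(img.unsqueeze(0).to(device))
    pred = logits.argmax(dim=1).item()
print('Predicted: {}, ground truth: {}'.format(pred, label))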