知识蒸馏:《Distilling the Knowledge in a Neural Network》算法介绍及PyTorch代码实例

目录

一、摘要

二、 蒸馏算法

三、代码

四、References        


一、摘要

        提高几乎任何机器学习算法性能的一个非常简单的方法就是在相同的数据上训练许多不同的模型,然后平均它们的预测,或者对模型进行集成然后投票(vote),即多模型集成可以显著提升机器学习性能。很不幸,使用整个集成模型进行预测是很麻烦的,而且可能计算成本太高,若部署到用户群体非常庞大的情景下,每一个用户所产生的的输入都要在整个集成模型上运行一次,这对算力的要求太高。《Model Compression》这篇文献发现:将集成学习模型学习到的知识压缩到单个模型中后,模型部署就会变得容易许多。本文继承了这种思想,并提出了一种新的模型压缩方法——“知识蒸馏”(Knowledge Distilling, KD)。该方法在MNIST数据集上取得了令人惊讶的结果,并且本文展示了可以通过将一个集成模型中的知识蒸馏到一个单一的模型中,可以显著地改进一个已经大规模商业应用的语音模型的性能。本文还提出了一种由一个通用模型(full models)和许多专用模型(specialist models)构成的模型集成范式,后者用以识别通用模型容易混淆的细粒度类别。与以前专家模型(expert models)的范式不同,专用模型可以快速地并行训练。

         Many insects have a larval form that is optimized for extracting energy and nutrients from the environment and a completely different adult form that is optimized for the very different requirements of traveling and reproduction.

        在大规模机器学习场景下,无论是训练阶段还是部署阶段,我们通常使用非常相似的模型,尽管训练和部署的需求并不相同:模型训练必须从规模非常大且高度冗余的数据集中提取特征,但它不需要实时操作,并且允许使用大量计算的计算资源。然而,模型部署到具有大量用户的场景下时,对延迟和计算资源有着更严格的要求。训练得到的模型往往是非常庞大的,或是采用集成学习得到,或是采用正则化手段训练的单一大模型。一旦繁琐/庞大的模型被训练好,我们就可以使用一种不同的训练手段,称之为“蒸馏”,将大模型学到的知识迁移到一个更适合部署的小模型,前人的工作已经证明了这一点。

        但是,如何定义并量化“知识”(Knowledge)这个概念是一大难点。通常我们认为模型学习到的参数代表了知识,但这是非常片面的,因为大模型和小模型的结构、参数有着明显的差异,将大模型的参数迁移/复制到小模型上来更无从谈起。教师网络(即大模型)的输出预测概率中各类别概率的相对大小隐式地包含了“知识”,即使是对于非正确类别的那些概率而言,它们的相对大小包含着非常重要的信息。例如,一辆车的真实标签是宝马,其被错误地识别为垃圾车的概率很小,但是其被认为是垃圾车的概率显然要远远大于其被认为是胡萝卜的概率。想要让学生网络(即小模型)在测试集上拥有优秀的泛化性能,就需要知道“知识”如何被定义并量化,这样才能让学生网络学习与教师网络相同的“知识”。

        一种方法是采用“Soft Targets”来表示知识,即将教师网络产生的各个类别的概率作为soft targets来训练学生网络。Soft targets相较于hard tagets而言拥有更高的熵,那么包含的信息也就越丰富,因此在训练学生网络时可以使用更少的数据和更大的学习率。

        上面这段是由论文introduction第4段的本意总结的,其中有几个比较令人困惑的点,写一下个人观点,若有不当之处,还望批评指正:(1)为什么soft targets的熵更高:熵表征系统混乱程度,hard targets这种非0即1的表示方法显然具有极高的确定性,因此熵低,而soft targets展示出了相对概率大小(如上面宝马的例子),不确定性程度更高,熵更高;(2)为什么熵高就包含更多的信息:关于熵的大小和信息量的大小之间的关系众说纷纭,有说熵越大信息量越大的,也有说熵越小信息量越小的,我没有学过信息论,但我认为他们都忽略了一个定语,即什么样的信息,这样描述或许会更容易理解:“熵越大,系统混乱程度越大,其包含的不确定性信息越多,包含的确定性信息越少”,这里放一个知乎,他说“熵减”与信息量的大小才是相呼应的,而非是熵,熵越大,信息量到底是越大还是越小?-知乎;(3)为什么使用soft targets后训练学生网络就可以使用更大的学习率:我也不知道,玄学。

        这几点都不是本文研究的重点,所以不必太过在意,记住就好。

        更新,关于第(2)又有新发现:信息熵越大,信息量到底是越大还是越小? - 知乎,这个是从熵的计算方法的角度阐述的,他提到的熵权法和soft targets有神似之处,我觉得可以作为正解。   

        另外,概率的绝对大小也是很重要的,因为过小的logits经过sofmax之后得到的概率会接近于0,这就导致这个概率在交叉熵中几乎得不到体现。在前人的工作中,他们采用softmax层之前的logits作为targets,使用均方误差对教师网络和学生网络的logits做损失,以此来规避经过softmax后得到的概率过小的问题。本文提出了更加通用的方法,叫做“蒸馏”,该方法通过提高softmax的温度T来得到恰当的soft targets,然后在训练学生网络来拟合该soft targets时采用相同的温度T。

6856eef514a0476399382117f7469ad0.png

二、 蒸馏算法

        神经网络通常使用softmax层将logits转换为概率,“蒸馏”将softmax中引入一个温度T来s生更加soft的概率分布,如上式所示,且T越高,所产生的概率分布越soft。在训练阶段,教师网络和学生网络采用相同的温度T进行蒸馏;在推理阶段,训练好的学生网络使用T=1即默认的softmax进行推理。

9bc3b52548ae4d70a64420fd30d598d8.png

        损失函数方面,总损失=λ·hardloss+(1-λ)T²·softlossSoft Loss又称Distillation Loss,它是将教师网络经过温度T=t蒸馏后的输出概率当做labels,即soft labels/targets,将学生网络经过温度T=t蒸馏后的输出概率当做预测值,即soft predictions,二者进行交叉熵损失作为Soft LossHard Loss又称Student Loss,它是将学生网络经过T=1蒸馏(即默认的softmax)后的输出概率作为预测值,即hard predictions,将输入图像的one-hot编码的hard label作为真实值,二者进行交叉熵损失计算作为Hard Loss。由于Soft Loss对logits的偏导数的magnitude大约是Hard Loss对logits的偏导数的1/T² ,因此Soft Loss前面乘一个T²,这样才能保证soft target和hard target贡献的梯度量基本一致。

e4e2a0b026164772b37e975635f3fc54.png

三、代码

# 代码(1)
"""使用ResNet及CIFAR10进行实验,GPU性能高的同学可以用这段代码"""
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 随机种子和cuda配置
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # 使用cudnn加速卷积运算

# 加载数据集
train_dataset = torchvision.datasets.CIFAR10(root='dataset/', train=True,
                                           transform=transforms.ToTensor(), download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='dataset/', train=False,
                                          transform=transforms.ToTensor(), download=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)


# 创建教师模型
model = torchvision.models.resnet34(pretrained=False)  # 实例化
model = model.to(device)  # 指定到device

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 3
for epoch in range(epochs):
    model.train()  # 训练模式
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)  # 将data指认到device
        targets = targets.to(device)  # 将targets指认到device
        preds = model(data)  # 前向传播得到预测结果
        loss = criterion(preds, targets)  # 交叉熵损失

        optimizer.zero_grad()  # 清空梯度信息
        loss.backward()  # 损失反向传播
        optimizer.step()  # 对网络参数进行优化

    # 进入测试模式
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():  # 固定所有参数的梯度为0,因为测试阶段不需要进行优化
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)  # 前向传播得到测试结果,preds为一个向量
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

teacher_model = model.to(device)

# 这部分仅仅是为了展示单独训练一个学生模型时的效果,与采用蒸馏训练对比一下
model = torchvision.models.resnet18(pretrained=False)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 3
for epoch in range(epochs):
    model.train()
    # 在训练集上训练
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        preds = model(data)
        loss = criterion(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

student_model_scratch = model

"""------------------------------蒸 馏----------------------------------"""
teacher_model.eval()  # 准备预训练好的教师模型
stu_ditillation_model = torchvision.models.resnet18()  # 准备新的学生模型
stu_ditillation_model = stu_ditillation_model.to(device)
stu_ditillation_model.train()

temp = 7  # 蒸馏温度
hard_loss = nn.CrossEntropyLoss()
alpha = 0.3  # hard_loss权重
soft_loss = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(stu_ditillation_model.parameters(), lr=1e-4)

epochs = 3
for epoch in range(epochs):
    # 训练集上训练学生模型的权重
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        with torch.no_grad(): # 教师模型预测
            teachers_preds = teacher_model(data)

        students_preds = stu_ditillation_model(data)

        # 损失函数
        students_loss = hard_loss(students_preds, targets)
        ditillation_loss = soft_loss(
            F.softmax(students_preds / temp, dim=1),
            F.softmax(teachers_preds / temp, dim=1)
        )
        loss = alpha * students_loss + (1 - alpha) * ditillation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型性能
    stu_ditillation_model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = stu_ditillation_model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
# 代码(2)
"""GPU性能一般的同学可以用这段代码"""
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 随机种子和cuda配置
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # 使用cudnn加速卷积运算

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='dataset/', train=True,
                                           transform=transforms.ToTensor(), download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='dataset/', train=False,
                                          transform=transforms.ToTensor(), download=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)


# 创建教师模型
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x


model = TeacherModel()  # 实例化
model = model.to(device)  # 指定到device

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10
for epoch in range(epochs):
    model.train()  # 训练模式
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)  # 将data指认到device
        targets = targets.to(device)  # 将targets指认到device
        preds = model(data)  # 前向传播得到预测结果
        loss = criterion(preds, targets)  # 交叉熵损失

        optimizer.zero_grad()  # 清空梯度信息
        loss.backward()  # 损失反向传播
        optimizer.step()  # 对网络参数进行优化

    # 进入测试模式
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():  # 固定所有参数的梯度为0,因为测试阶段不需要进行优化
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)  # 前向传播得到测试结果,preds为一个向量
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

teacher_model = model.to(device)


class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

# 这部分仅仅是为了展示单独训练一个学生模型时的效果,与采用蒸馏训练对比一下
model = StudentModel()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10
for epoch in range(epochs):
    model.train()
    # 在训练集上训练
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        preds = model(data)
        loss = criterion(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

student_model_scratch = model

"""------------------------------蒸 馏----------------------------------"""
teacher_model.eval()  # 准备预训练好的教师模型
stu_ditillation_model = StudentModel()  # 准备新的学生模型
stu_ditillation_model = stu_ditillation_model.to(device)
stu_ditillation_model.train()

temp = 7  # 蒸馏温度
hard_loss = nn.CrossEntropyLoss()
alpha = 0.3  # hard_loss权重
soft_loss = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(stu_ditillation_model.parameters(), lr=1e-4)

epochs = 10
for epoch in range(epochs):
    # 训练集上训练学生模型的权重
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        with torch.no_grad(): # 教师模型预测
            teachers_preds = teacher_model(data)

        students_preds = stu_ditillation_model(data)

        # 损失函数
        students_loss = hard_loss(students_preds, targets)
        ditillation_loss = soft_loss(
            F.softmax(students_preds / temp, dim=1),
            F.softmax(teachers_preds / temp, dim=1)
        )
        loss = alpha * students_loss + (1 - alpha) * ditillation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型性能
    stu_ditillation_model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = stu_ditillation_model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

四、References        

[1] Knowledge Distillation

[2] 知识蒸馏(Knowledge Distillation)_Law-Yao的博客-CSDN博客_只是蒸馏(墙裂安利)

[3] 【精读AI论文】知识蒸馏_哔哩哔哩_bilibili

  • 8
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
将神经网络中的知识进行提取,是一种将模型的信息转化为更为简洁和易于理解形式的过程。 神经网络是一种由许多神经元组成的复杂计算模型,它们通过学习和调整权重来解决各种问题。然而,神经网络通常具有大量的参数和复杂的结构,这使得它们难以解释和应用到其他领域。因此,我们需要一种方法来提取和总结神经网络中的知识,以便更好地理解和应用这些模型。 在进行神经网络知识提取时,有几种常见的方法。一种常见的方法是使用可视化技术,如热力图、激活图和网络结构图等,来可视化网络中不同层的活动模式。这些可视化技术能够帮助我们发现网络中的模式和特征,并从中推断出网络的知识。 另一种方法是使用特征提取技术,如卷积神经网络(CNN)的滤波器、自动编码器的隐藏层和循环神经网络(RNN)的隐状态等,来提取网络学习到的重要特征。这些重要特征可以帮助我们更好地理解网络学习到的信息,并将其应用到其他问题中。 此外,还有一种被称为知识蒸馏的技术,它通过训练一个较小的模型来提取大型模型中的知识。知识蒸馏通过引入目标函数和额外的训练策略,使小模型能够学习到大模型中的重要知识,并在不损失太多性能的情况下将其应用到实际问题中。 总而言之,提取神经网络中的知识是一项重要任务,它能够帮助我们更好地理解和应用这些复杂的模型。通过可视化、特征提取和知识蒸馏等方法,我们能够从神经网络中提取出有用的信息,并将其应用到其他领域或解决其他问题中。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Taylor不想被展开

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值