Knowledge Distillation: Code Implementation (contents: a knowledge-distillation model for MNIST handwritten-digit recognition, custom MLP networks as teacher and student, training results saved to log files, and a comparison of results from different ways of computing the distillation loss)

1. Classifying MNIST handwritten digits with knowledge distillation

1.1 utils.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

def load_data():
    # Load the MNIST training set
    train_dataset = torchvision.datasets.MNIST(
        root = "../datasets/",
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )

    # Load the MNIST test set
    test_dataset = torchvision.datasets.MNIST(
        root = "../datasets/",
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )

    # Build DataLoaders for the training and test sets
    train_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)
    test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)
    return train_dataloader,test_dataloader



class Logger(object):
    """Minimal file logger. The file is reopened on every write so each
    message is flushed to disk immediately."""
    def __init__(self, filename):
        self.filename = filename
        # Create (or touch) the log file up front
        f = open(self.filename + ".log", "a")
        f.close()

    def write(self, message):
        f = open(self.filename + ".log", "a")
        f.write(message)
        f.close()
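A quick usage sketch of these two helpers, assuming a `logs/` directory exists (the batch size of 12 comes from `load_data()` above):

```python
import utils

train_dataloader, test_dataloader = utils.load_data()
x, y = next(iter(train_dataloader))
print(x.shape, y.shape)  # torch.Size([12, 1, 28, 28]) torch.Size([12])

logger = utils.Logger("logs/demo")  # appends to logs/demo.log
logger.write("hello\n")
```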

1.2 models.py

import torch
from torch import nn
# Teacher model
class TeacherModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,num_classes)
        self.dropout = nn.Dropout(p=0.5)  # p=0.5 drops half of this layer's neurons
    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)
        return x

class StudentModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,20)
        self.fc2 = nn.Linear(20,20)
        self.fc3 = nn.Linear(20,num_classes)
    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.relu(x)

        x = self.fc3(x)
        return x
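Counting parameters makes the compression concrete: from the layer sizes above, the teacher has 2,395,210 parameters and the student only 16,330, roughly a 147x reduction. A quick check:

```python
import models

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(models.TeacherModel()))  # 2395210
print(count(models.StudentModel()))  # 16330
```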

1.3 train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as F

def train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device,logger):
    # ---------------------- start timing -----------------------------------
    start_time = time.time()

    # Set up the loss, optimizer, and training state
    best_acc, best_epoch = 0, 0
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        # Train the model weights on the training set
        for data, targets in tqdm.tqdm(train_dataloader):
            # Move the batch to the device
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            preds = model(data)
            loss = criterion(preds, targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the model on the test set
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_dataloader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)
                predictions = preds.max(1).indices  # column index of the max logit in each row, i.e. the predicted class
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()
            if acc > best_acc:
                best_acc = acc
                best_epoch = epoch
                # Save the parameters with the best accuracy so far (the weights/ directory must exist)
                torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")
        model.train()

        logger.write('Epoch:{}\t Accuracy:{:.4f}\t loss={:.4f}\n'.format(epoch + 1, acc, loss.item()))
        if epoch % 10 == 0 and epoch != 0:
            logger.write(
                f"------------------------ best accuracy so far: {best_acc}, at epoch {best_epoch} --------------------\n")
    logger.write(
        f'Best accuracy: {best_acc}, at epoch {best_epoch}; best parameters saved to weights/{model_name}_best_acc_params.pth' + '\n')

# ------------------------- stop timing ------------------------------------
    end_time = time.time()
    run_time = end_time - start_time
    # Keep two decimal places in the reported time
    if int(run_time) < 60:
        logger.write(f'Training time: {round(run_time, 2)}s' + '\n')
    else:
        logger.write(f'Training time: {round(run_time / 60, 2)} minutes' + '\n')

def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device,logger):
    # ------------------------------------- start timing --------------------------------
    start_time = time.time()

    # Define the loss functions
    hard_loss = nn.CrossEntropyLoss()
    soft_loss = nn.KLDivLoss(reduction="batchmean")

    # Define the optimizer
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)

    best_acc,best_epoch = 0,0
    for epoch in range(epochs):
        student_model.train()
        # Train the student weights on the training set
        for data,targets in tqdm.tqdm(train_dataloader):
            # Move the batch to the device
            data = data.to(device)
            targets = targets.to(device)

            # Teacher predictions (no gradients needed)
            with torch.no_grad():
                teacher_preds = teacher_model(data)
            # Student predictions
            student_preds = student_model(data)
            # Hard loss against the ground-truth labels
            student_hard_loss = hard_loss(student_preds,targets)


            # ChatGPT-style loss (see section 2.1): KLDivLoss expects
            # log-probabilities as input and probabilities as target
            soft_student_outputs = F.log_softmax(student_preds / temp, dim=1)
            soft_teacher_outputs = F.softmax(teacher_preds / temp, dim=1)

            distillation_loss = soft_loss(soft_student_outputs, soft_teacher_outputs)
            loss = alpha * student_hard_loss + (1 - alpha) * temp * temp * distillation_loss


            # Backward pass and weight update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the student on the test set
        student_model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x,y in test_dataloader:
                x = x.to(device)
                y = y.to(device)
                preds = student_model(x)
                predictions = preds.max(1).indices  # column index of the max logit in each row, i.e. the predicted class
                num_correct += (predictions ==y).sum()
                num_samples += predictions.size(0)
            acc = (num_correct/num_samples).item()
            if acc>best_acc:
                best_acc = acc
                best_epoch = epoch
                # Save the parameters with the best accuracy so far
                torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")
        student_model.train()

        logger.write('Epoch:{}\t Accuracy:{:.4f}\t'.format(epoch + 1, acc)
                     + f'student_hard_loss={student_hard_loss.item():.4f}, distillation_loss={distillation_loss.item():.4f}, loss={loss.item():.4f}' + '\n')

        if epoch % 10 == 0 and epoch != 0:
            logger.write(f"------------------------ best accuracy so far: {best_acc}, at epoch {best_epoch} --------------------\n")

    logger.write(f'Best accuracy: {best_acc}, at epoch {best_epoch}; best parameters saved to weights/{model_name}_best_acc_params.pth' + '\n')

    # -------------------------------- stop timing ----------------------------------
    end_time = time.time()
    run_time = end_time - start_time
    # Keep two decimal places in the reported time
    if int(run_time) < 60:
        logger.write(f'Training time: {round(run_time, 2)}s' + '\n')
    else:
        logger.write(f'Training time: {round(run_time / 60, 2)} minutes' + '\n')
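For reference, the combined objective above is loss = alpha * CE(student, y) + (1 - alpha) * T^2 * KL(softmax(teacher/T) || softmax(student/T)); the T^2 factor compensates for the soft-target gradients shrinking by 1/T^2. A minimal standalone sketch of the same loss, just wrapped in a function:

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, alpha=0.3, temp=7.0):
    """Hard cross-entropy plus the temperature-scaled KL distillation term."""
    hard = F.cross_entropy(student_logits, targets)
    soft = F.kl_div(
        F.log_softmax(student_logits / temp, dim=1),  # input: log-probabilities
        F.softmax(teacher_logits / temp, dim=1),      # target: probabilities
        reduction="batchmean",
    )
    return alpha * hard + (1 - alpha) * temp * temp * soft
```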

1.4 Training the teacher network

import torch
from torchinfo import summary  # for printing a model summary
import models
import utils
import train_tools
from datetime import datetime


# Hyperparameters
dataset_name = "MNIST"
lr = 1e-4
model_name = 'teacher'
epochs = 50
alpha = 0.3  # weight of the hard loss
temp = 7  # distillation temperature
random_seed = 0  # manual random seed


log_name = (
    "logs/"
    + model_name
    + "_"
    + dataset_name
    + "_"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # no colons, so the name is also valid on Windows
)


# Create the logger
logger = utils.Logger(log_name)
logger.write(f"lr={lr},epochs={epochs},alpah={alpha},temp={temp},manual_seed({random_seed})")


# Set the random seed
torch.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Enable cuDNN benchmarking to speed up computation
torch.backends.cudnn.benchmark = True

# Load the MNIST training and test sets
train_dataloader,test_dataloader = utils.load_data()

# Create the teacher model
model = models.TeacherModel()
model = model.to(device)

# Log the model summary
logger.write(str(summary(model))+'\n')

train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device,logger)
Best accuracy: 0.9869, reached at epoch 9.
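Before moving on to distillation, it is worth confirming that the saved checkpoint reloads correctly. A quick sketch reusing the evaluation loop from train_tools (same variables as in the script above):

```python
# Reload the best teacher weights and re-check test accuracy
model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth', map_location=device))
model.eval()
num_correct, num_samples = 0, 0
with torch.no_grad():
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        predictions = model(x).max(1).indices
        num_correct += (predictions == y).sum()
        num_samples += predictions.size(0)
print(f"reloaded teacher accuracy: {(num_correct / num_samples).item():.4f}")  # should match the logged best
```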

1.5 Training the student network without distillation

import torch
from torchinfo import summary  # for printing a model summary
from datetime import datetime
import utils
import models
import train_tools

# Set the random seed
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Enable cuDNN benchmarking to speed up computation
torch.backends.cudnn.benchmark = True

# Build the training and test DataLoaders
train_dataloader,test_dataloader = utils.load_data()

# Train the student model from scratch
model = models.StudentModel()
model = model.to(device)
# Inspect the model parameters
print(summary(model))

# Set the hyperparameters, create a logger, and start training
epochs = 50
lr = 1e-4
model_name = 'student'
logger = utils.Logger("logs/" + model_name + "_MNIST_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
train_tools.train(epochs, model, model_name, lr, train_dataloader, test_dataloader, device, logger)

Best accuracy: 0.9383, reached at epoch 9; best parameters saved to weights/student_best_acc_params.pth.
Training time: 1.74 minutes.

1.6 Training the student model with knowledge distillation

For this method the teacher network must already be trained; load the best teacher parameters before distilling.

import torch
import train_tools
import models
import utils
from torchinfo import summary  # for printing a model summary
from datetime import datetime

# Hyperparameters
model_name = 'distill_student'
dataset_name = "MNIST"
lr = 1e-4
epochs = 50
alpha = 0.3  # weight of the hard loss
temp = 7  # distillation temperature
random_seed = 0  # manual random seed


log_name = (
    "logs/"
    + model_name
    + "_"
    + dataset_name
    + "_"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # no colons, so the name is also valid on Windows
)


# Create the logger
logger = utils.Logger(log_name)
logger.write(f"lr={lr},epochs={epochs},alpah={alpha},temp={temp},manual_seed({random_seed})")

# Set the random seed
torch.manual_seed(random_seed)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Enable cuDNN benchmarking to speed up computation
torch.backends.cudnn.benchmark = True

# Load the data
train_dataloader,test_dataloader = utils.load_data()


# Load the trained teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth', map_location=device))
teacher_model.eval()

# Prepare a fresh student model
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()

# Log the model summary
logger.write(str(summary(student_model))+'\n')

# Start distillation training by calling distill_train from train_tools
train_tools.distill_train(epochs, teacher_model, student_model, model_name, train_dataloader, test_dataloader, alpha, lr, temp, device, logger)

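After training, the distilled student can be reloaded for inference. A minimal sketch (paths as configured above):

```python
# Load the best distilled-student weights for evaluation
best_student = models.StudentModel().to(device)
best_student.load_state_dict(torch.load('../weights/distill_student_best_acc_params.pth', map_location=device))
best_student.eval()
```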

2. Comparing results from different ways of computing the distillation loss

2.1 ChatGPT version (the best of the three versions compared here)

Change the loss computation inside the distill_train() function to the following:

soft_student_outputs = F.log_softmax(student_preds / temp, dim=1)
soft_teacher_outputs = F.softmax(teacher_preds / temp, dim=1)
distillation_loss = soft_loss(soft_student_outputs, soft_teacher_outputs)
loss = alpha * student_hard_loss + (1 - alpha) * temp * temp * distillation_loss

With epochs=10:

Best accuracy: 0.9287, reached at epoch 9.
Training time: 2.1 minutes.

![training log](https://img-blog.csdnimg.cn/direct/ae1c768f31384418b2767fbf2cf54f39.png)

With epochs=50:

[figure: training log, epochs=50]

2.2 同济子豪兄's version of the loss function

# 同济子豪兄's version
# Soft loss on the temperature-scaled predictions. Note: KLDivLoss expects
# log-probabilities as its first argument, but plain softmax is used here --
# see the discussion below.
distillation_loss = soft_loss(
    F.softmax(student_preds / temp, dim=1),
    F.softmax(teacher_preds / temp, dim=1)
)

# Distillation loss weighted by temp squared
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * distillation_loss

With this version the loss can go negative and the results are poor. The reason is that nn.KLDivLoss expects its first argument to be log-probabilities; passing softmax probabilities means the computed quantity is no longer a true KL divergence, so it is not bounded below by zero.
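A minimal sketch of the effect, using hypothetical two-class distributions (not values from the training run):

```python
import torch

kl = torch.nn.KLDivLoss(reduction="batchmean")
p = torch.tensor([[0.9, 0.1]])  # "student" distribution
q = torch.tensor([[0.8, 0.2]])  # "teacher" distribution

print(kl(p.log(), q))  # correct usage (log-probs first): ~0.0444, always >= 0
print(kl(p, q))        # buggy usage (raw probs first): ~-1.2404, can go negative
```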

2.3 文心一言 (ERNIE Bot) version

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature):
    """
    Distillation loss.
    :param student_logits: student predictions, shape (batch_size, num_classes)
    :param teacher_logits: teacher predictions, shape (batch_size, num_classes)
    :param temperature: temperature controlling how soft the distributions are
    :return: distillation loss
    """
    # Temperature-softened probability distributions
    student_probs = F.softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)

    # KL divergence (log-probabilities as input), scaled by temperature squared
    kl_divergence = F.kl_div(student_probs.log(), teacher_probs, reduction='batchmean') * (temperature ** 2)

    # Multiplying by temperature once more gives an overall T^3 scaling
    return kl_divergence * temperature
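Wired into the distill_train() loop from section 1.3, it would replace the soft-loss lines roughly as follows (variable names as in section 1.3; the exact alpha weighting here is an assumption):

```python
student_hard_loss = hard_loss(student_preds, targets)
soft = distillation_loss(student_preds, teacher_preds, temp)
loss = alpha * student_hard_loss + (1 - alpha) * soft
```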
Best accuracy: 0.9290, reached at epoch 9.
Training time: 2.14 minutes.

With the 文心一言 version, the hard loss and the distillation loss are not on the same order of magnitude: the soft term carries an overall factor of temperature cubed (7^3 = 343 here). The loss does not go negative, though, since F.kl_div is correctly given log-probabilities as its first argument.

With epochs set to 50:

Best accuracy: 0.9586, reached at epoch 49.
Training time: 11.09 minutes.

Other open-source knowledge-distillation implementations are listed below:

3. Other off-the-shelf knowledge-distillation code

The open-source toolboxes from open-mmlab include knowledge-distillation algorithms.

mmrazor

github.com/open-mmlab/mmrazor


NAS: neural architecture search
Pruning: network pruning
KD: knowledge distillation
Quantization: model quantization

mmrazor also supports custom knowledge-distillation algorithms.

mmdeploy

It can deploy models to vendor-supported intermediate formats such as ONNX and TensorRT.


HobbitLong's RepDistiller

github.com/HobbitLong/RepDistiller

It implements 12 state-of-the-art knowledge-distillation algorithms.

Distillation can be applied within the same model family, transferring what a large network has learned into a smaller one; for example, resnet100 as the teacher network and resnet32 as the student network.


Knowledge can also be transferred from one architecture to another, for example vgg13 as the teacher network and MobileNetV2 as the student network.


4. Video tutorial

神经网络知识蒸馏 Knowledge Distillation

https://www.bilibili.com/video/BV1s7411h7K2/?spm_id_from=333.337.search-card.all.click&vd_source=ebc47f36e62b223817b8e0edff181613

