深度学习——(12)Knowledge distillation(Demo)

本文介绍了一个使用PyTorch实现的知识蒸馏Demo,通过ResNet18和ResNet50模型进行知识转移。首先定义数据加载器,然后加载预训练权重,接着实现知识蒸馏损失函数KD_loss,最后进行训练并测试模型性能。作者分享了学习知识蒸馏的心得,并提到近期工作繁忙,但依然坚持学习和实践。
摘要由CSDN通过智能技术生成

深度学习——(12)Knowledge distillation(Demo)

原本昨天晚上要写的,但是奈何手中有更紧迫的任务需要做,所以自己还没有实战,昨天看到了一个简单的demo,自己写了一部分注释,希望对大家有帮助。等我把手头的活干完,再来接着详细说

# -*- coding: utf-8 -*-
"""
Created on Sat Sep 24 09:23:35 2022

@author: Lenovo
"""

from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn

resnet18_pretrain_weight = "./weights/resnet18-5c106cde.pth"   
resnet50_pretrain_weight = "./weights/resnet50_cifar10.pth"
img_dir = "/data/cifar10/"


def create_data(img_dir): 
    '''
    根据img_dir中的图片创建dataloader,定义transformer,batchsize,并行数
    其实都是可以定义在前面的参数,但是作者在这个地方是写死的,可以作为函数中的一个变量来进行定义(每一次都改很麻烦的)
    batch_size,num_work
    '''
    dataset = dst.CIFAR10
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=512, shuffle=True, num_workers=4, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=512, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader

def load_checkpoint(net, pth_file, exclude_fc=False): 
    '''
    加载模型权重
    :net(Module) 定义的网络结构
    :pth_file 权重路径
    :exclude_fc 是否去除全连接层,如果exclude_fc为True,表示网络加载的时候删除最后全连接层,否则表示保持完整网络不做删除
    '''
    if exclude_fc:
        model_dict = net.state_dict()
        pretrain_dict = torch.load(pth_file)
        new_dict = {k: v for k, v in pretrain_dict.items() if 'fc' not in k}
        model_dict.update(new_dict)
        net.load_state_dict(model_dict, strict=True)
    else:
        pretrain_dict = torch.load(pth_file)
        net.load_state_dict(pretrain_dict, strict=True)


def accuracy(output, target, topk=(1,)):
    """
    计算准确率
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class KD_loss(nn.Module):
    '''
    简单的知识蒸馏部分
    注:命名为KD_loss但其实是个model,所以继承了Module
    核心就是计算loss,所以在forward部分直接定义为计算loss(student模型和teacher模型之间的loss)
    '''
    def __init__(self, T):
        super(KD_loss, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        '''
        前向传播
        计算student网络的输出和teacher网络的输出之间的KL散度,此处是teacher网络知道student网络
        所以student网络在前,且为避免KL散度计算出负值,第一个参数需要是对数概率,所以使用log_softmax
        '''
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss


def test(net, test_loader):
    '''
    相当于一般的predict过程
    '''
    prec1_sum = 0
    prec5_sum = 0
    net.eval()
    for i, (img, target) in enumerate(test_loader, start=1):
        # print(f"batch: {i}")
        img = img.cuda()
        target = target.cuda()

        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
        # print(f"batch: {i}, acc1:{prec1}, acc5:{prec5}")
    print(f"Acc1:{prec1_sum / (i + 1)}, Acc5: {prec5_sum / (i + 1)}")


def train(net_s, net_t, train_loader, test_loader):
    '''
    训练过程
    '''
    opt = Adam(net_s.parameters(), lr=0.0001)
    net_s.train()
    net_t.eval()
    for epoch in range(100):
        for step, batch in enumerate(train_loader):
            opt.zero_grad()
            image, target = batch
            image = image.cuda()
            target = target.cuda()
            out_s, out_t = net_s(image), net_t(image)
            loss_init = CrossEntropyLoss()(out_s, target) # 先计算student模型的结果和真正的(硬label)之间的loss
            loss_kd = KD_loss(T=4)(out_s, out_t)  # 计算student模型生成的结果(概率分布状况)和teacher模型生成的结果之间的KL散度
            loss = loss_init + loss_kd # 最后的loss定义为两个分布之间的差异loss以及由student模型预测的label和真正label之间的loss
            # prec1, prec5 = accuracy(predict, target, topk=(1, 5))
            # print(f"epoch:{epoch}, step:{step}, loss:{loss.item()}, acc1: {prec1},acc5:{prec5}")
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)

    torch.save(net_s.state_dict(), './resnet18_cifar10_kd.pth')


def main():
    net_t = resnet50(num_classes=10)    # 将teacher模型定义为resnet50 
    net_s = resnet18(num_classes=10)    # teacher模型定义为resnet18
    net_t = net_t.cuda()
    net_s = net_s.cuda()
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)
    # for name, value in net.named_parameters():
    #     if 'fc' not in name:
    #         value.requires_grad = False

    train_loader, test_loader = create_data(img_dir)
    train(net_s, net_t, train_loader, test_loader)
    # test(net, test_loader)


if __name__ == "__main__":
    main()

注 1 :上面的模型在有GPU,装了cuda的机子上使用,在windows上使用时需要将上面的.cuda()都去掉,或者在前面加device 判断,直接.device()。这里我就不给大家改了,使用的话自取,若有问题,欢迎讨论。
感 1 :最近因为要把以前训练好的模型权重作为新的模型输入,将几个模型整合在一起考虑更多的特征信息,所以看了知识蒸馏,觉得这个模型,咦,有点意思!其实之前课题起步阶段看过一点,当时是一篇文献中好像有个方法叫DINO,那个时候初次认识知识蒸馏,后来想着下来看一下这篇引用的文献,结果一拖再拖,到前几天又提起,所以临时做的功课。
感 2 :最近所有事情都来了,专利需要改稿,方法需要再优化一些,上周末有了新的思路,想要try一下,只给我两周时间,如果不可以直接pass,一共还有三个step没有处理,但是现在step1还刚有了雏形,文章背景还没写。加油吧,羊。过了这段时间应该会轻松一点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柚子味的羊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值