深度学习——(12)Knowledge distillation(Demo)
原本昨天晚上要写的,但是奈何手中有更紧迫的任务需要做,所以自己还没有实战,昨天看到了一个简单的demo,自己写了一部分注释,希望对大家有帮助。等我把手头的活干完,再来接着详细说
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 24 09:23:35 2022
@author: Lenovo
"""
from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn
resnet18_pretrain_weight = "./weights/resnet18-5c106cde.pth"
resnet50_pretrain_weight = "./weights/resnet50_cifar10.pth"
img_dir = "/data/cifar10/"
def create_data(img_dir, batch_size=512, num_workers=4):
    """Build the CIFAR-10 train/test DataLoaders.

    The original hard-coded batch_size and num_workers inside the function
    (its own docstring complained about exactly that); they are now
    parameters with the old values as defaults, so existing callers keep
    working unchanged.

    :param img_dir: root directory for the CIFAR-10 data (downloaded if absent)
    :param batch_size: batch size used for both loaders
    :param num_workers: number of DataLoader worker processes
    :return: (train_loader, test_loader)
    """
    dataset = dst.CIFAR10
    # CIFAR-10 per-channel mean/std used for input normalisation.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        # Reflect-pad by 4 then random-crop back to 32: random shift augmentation.
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
def load_checkpoint(net, pth_file, exclude_fc=False):
    """Load a saved state-dict from *pth_file* into *net*.

    :param net: module that receives the weights
    :param pth_file: path to the checkpoint file
    :param exclude_fc: when True, every checkpoint key containing 'fc' is
        dropped, so the network's final classifier keeps its own (fresh)
        initialisation; when False the full checkpoint is loaded as-is
    """
    checkpoint = torch.load(pth_file)
    if not exclude_fc:
        net.load_state_dict(checkpoint, strict=True)
        return
    # Start from the net's own parameters and overwrite everything
    # except the fully-connected head.
    merged = net.state_dict()
    merged.update({k: v for k, v in checkpoint.items() if 'fc' not in k})
    net.load_state_dict(merged, strict=True)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) from classification scores.

    :param output: (batch, num_classes) logit/score tensor
    :param target: (batch,) tensor of ground-truth class indices
    :param topk: tuple of k values to report
    :return: list of scalar tensors, one accuracy per k in *topk*
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # pred after t(): (maxk, batch), each column is a sample's ranking, best first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # Bug fix: correct[:k] is non-contiguous (it is a slice of a
        # transposed tensor), so .view(-1) raises a RuntimeError on modern
        # PyTorch whenever maxk > 1; .reshape(-1) handles both cases.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class KD_loss(nn.Module):
    """Soft-target distillation criterion.

    Computes the temperature-scaled KL divergence between the student's and
    the teacher's output distributions (classic knowledge distillation).
    Note: despite the name it is an nn.Module, so it is used like a layer.
    """

    def __init__(self, T):
        super(KD_loss, self).__init__()
        # Temperature: larger T softens both probability distributions.
        self.T = T

    def forward(self, out_s, out_t):
        """Return KL(student || teacher) at temperature T, scaled by T^2.

        The student logits go through log_softmax because F.kl_div expects
        log-probabilities as its first argument (the teacher guides the
        student); the T*T factor keeps gradient magnitudes comparable
        across temperatures.
        """
        temp = self.T
        log_p_student = F.log_softmax(out_s / temp, dim=1)
        p_teacher = F.softmax(out_t / temp, dim=1)
        kl = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
        return kl * temp * temp
def test(net, test_loader):
    """Evaluate *net* on *test_loader* and print mean top-1/top-5 accuracy.

    :param net: model to evaluate (switched to eval mode)
    :param test_loader: DataLoader yielding (image, target) batches
    """
    prec1_sum = 0
    prec5_sum = 0
    batches = 0
    # Run on whatever device the model already lives on instead of
    # hard-coding .cuda() — same behaviour on a GPU box, but also works on CPU.
    device = next(net.parameters()).device
    net.eval()
    for img, target in test_loader:
        img = img.to(device)
        target = target.to(device)
        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
        batches += 1
    if batches == 0:
        print("No test batches.")
        return
    # Bug fix: the original used enumerate(start=1) yet divided by i + 1,
    # i.e. by num_batches + 1, systematically under-reporting both accuracies.
    # NOTE(review): this is a mean over per-batch accuracies — exact only if
    # all batches have equal size (the last batch may be smaller).
    print(f"Acc1:{prec1_sum / batches}, Acc5: {prec5_sum / batches}")
def train(net_s, net_t, train_loader, test_loader, epochs=100):
    """Distill the teacher *net_t* into the student *net_s*.

    Total loss per step = CrossEntropy(student, hard labels)
                        + KD KL-divergence(student, teacher soft labels).

    :param net_s: student network (updated by Adam, lr=1e-4)
    :param net_t: teacher network (frozen, eval mode)
    :param train_loader: training batches of (image, target)
    :param test_loader: evaluation data, passed to test() after each epoch
    :param epochs: number of passes over train_loader (default keeps the
        original hard-coded 100)
    """
    opt = Adam(net_s.parameters(), lr=0.0001)
    # Hoisted out of the loop: both criteria are stateless, so rebuilding
    # them every step only wasted allocations.
    hard_criterion = CrossEntropyLoss()
    soft_criterion = KD_loss(T=4)
    net_t.eval()
    for epoch in range(epochs):
        # Bug fix: test() switches the student to eval mode at the end of
        # each epoch, so train mode must be re-asserted every epoch (the
        # original set it once and then trained epochs 2..N in eval mode).
        net_s.train()
        for step, (image, target) in enumerate(train_loader):
            opt.zero_grad()
            image = image.cuda()
            target = target.cuda()
            out_s = net_s(image)
            # Bug fix: the teacher is frozen, so run it under no_grad —
            # otherwise autograd tracks its activations (wasting memory) and
            # accumulates useless gradients into the teacher's parameters.
            with torch.no_grad():
                out_t = net_t(image)
            loss_init = hard_criterion(out_s, target)  # hard-label CE loss
            loss_kd = soft_criterion(out_s, out_t)     # soft-label distillation loss
            loss = loss_init + loss_kd
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)
        torch.save(net_s.state_dict(), './resnet18_cifar10_kd.pth')
def main():
    """Wire everything together: build teacher/student, load weights, distill."""
    net_t = resnet50(num_classes=10).cuda()  # teacher network
    net_s = resnet18(num_classes=10).cuda()  # student network (original comment mislabelled this line "teacher")
    # Teacher keeps its trained classifier; the student drops the
    # checkpoint's fc so its 10-class head starts from fresh init.
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)
    train_loader, test_loader = create_data(img_dir)
    train(net_s, net_t, train_loader, test_loader)


if __name__ == "__main__":
    main()
注 1:上面的代码需要在装有 CUDA 的 GPU 机器上运行;在没有 GPU 的机器(例如普通 Windows 电脑)上使用时,需要把代码中的 .cuda() 都去掉,或者在开头加上 device 判断(device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')),再用 .to(device) 替换各处的 .cuda()。这里我就不替大家改了,使用的话自取,若有问题,欢迎讨论。
感 1 :最近因为要把以前训练好的模型权重作为新的模型输入,将几个模型整合在一起考虑更多的特征信息,所以看了知识蒸馏,觉得这个模型,咦,有点意思!其实之前课题起步阶段看过一点,当时是一篇文献中好像有个方法叫DINO,那个时候初次认识知识蒸馏,后来想着下来看一下这篇引用的文献,结果一拖再拖,到前几天又提起,所以临时做的功课。
感 2 :最近所有事情都来了,专利需要改稿,方法需要再优化一些,上周末有了新的思路,想要try一下,只给我两周时间,如果不可以直接pass,一共还有三个step没有处理,但是现在step1还刚有了雏形,文章背景还没写。加油吧,羊。过了这段时间应该会轻松一点。