定义
知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单,有效,在工业界被广泛应用。
更简单的理解:用一个已经训练好的模型去“教”另一个模型去学习,这两个模型通常称为老师-学生模型。
用一个小例子来加深理解:
相关知识
pytorch中的损失函数:
Softmax:将一个数值序列映射到概率空间
# Softmax
import torch
import torch.nn.functional as F
# torch.nn是pytorch中自带的一个函数库,里面包含了神经网络中使用的一些常用函数,
# 如具有可学习参数的nn.Conv2d(),nn.Linear()和不具有可学习的参数(如ReLU,pool,DropOut等)(后面这几个是在nn.functional中)
# 在图片分类问题中,输入m张图片,输出一个m*N的Tensor,其中N是分类类别总数。
# 比如输入2张图片,分三类,最后的输出是一个2*3的Tensor,举个例子:
# torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)
output = torch.randn(2, 3)
print(output)
# tensor([[-1.1639, 0.2698, 1.5513],
# [-1.0839, 0.3102, -0.8798]])
# 第1,2行分别是第1,2张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
# 可以看出模型认为第一张为猪,第二张为狗。 然后对每一行使用Softmax,这样可以得到每张图片的概率分布。
print(F.softmax(output,dim=1))
# tensor([[0.1167, 0.1955, 0.6878],
# [0.8077, 0.0990, 0.0933]])
log_Softmax:在Softmax的基础上进行取对数运算
# log_softmax
print(F.log_softmax(output,dim=1))
print(torch.log(F.softmax(output,dim=1)))
tensor([[-1.8601, -0.7688, -0.9655],
[-0.9205, -1.1949, -1.2075]])
tensor([[-1.8601, -0.7688, -0.9655],
[-0.9205, -1.1949, -1.2075]]) # 结果是一致的
NLLLoss:对log_softmax和one-hot编码进行运算
# NLLLoss
print(F.nll_loss(torch.tensor([[-1.2, -0.03, -0.5]]), torch.tensor([0])))
注:Tensor是张量,所以至少为[[]]!!!
# 通常我们结合 log_softmax 和 nll_loss一起用
output = torch.tensor([[1.2,3,2.6]])
target = torch.tensor([0])
print("output为[[1.2,3,2.6]],若target为第一个,nll_loss为:",F.nll_loss(output,target))
target = torch.tensor([1])
print("output为[[1.2,3,2.6]],若target为第二个,nll_loss为:",F.nll_loss(output,target))
target = torch.tensor([2])
print("output为[[1.2,3,2.6]],若target为第二个,nll_loss为:",F.nll_loss(output,target))
输出结果:
output为[[1.2,3,2.6]],若target为第一个,nll_loss为: tensor(-1.2000)
output为[[1.2,3,2.6]],若target为第二个,nll_loss为: tensor(-3.)
output为[[1.2,3,2.6]],若target为第二个,nll_loss为: tensor(-2.6000)
CrossEntropy:衡量两个概率分布的差别
output = torch.tensor([[1.2,3,2.6]])
log_softmax_output = F.log_softmax(output,dim=1)
target = torch.tensor([0])
print(F.nll_loss(log_softmax_output,target))
print(F.cross_entropy(output,target)) # 交叉熵自带softmax
输出结果:
tensor(2.4074)
tensor(2.4074)
图解KD
图中猫的图片的one-hot编码先输入到Teacher网络中进行训练得到q’,在通过蒸馏得到q’’,最后得到soft targets,然后再把猫的图片输入到Student网络中,得到hard targets并计算损失函数,最后和来自Teacher网络预测结果的损失函数相加得到最后的损失函数。
知识蒸馏过程
知识蒸馏应用场景
知识蒸馏和迁移学习的基本区别
迁移学习:是从一个领域获取得模型应用到别的领域的学习
知识蒸馏:是在同一个领域中,从大模型迁移到小模型上的学习