【模型压缩】关于知识蒸馏(Distill)的一次实验
1. 简介
知识蒸馏(Knowledge Distillation)旨在使用一个复杂的教师网络(Teacher Net)来引导一个小的学生网络(Student Net),从而达到压缩模型的目的。在普通的有监督训练中,我们一般使用一个数来作为标签,从而引导网络向这个具体的数逼近。而在知识蒸馏方法中,教师网络的函数已知,学生网络的函数未知;用已知函数(教师网络)的输出作为监督信号,引导学生网络的函数向它拟合,显然是可行的。
2. 网络模型的定义
我的思路是:先训练一个大的教师网络,然后使用教师网络的特征图作为标签和数据集的标签一起计算学生网络的损失。具体做法如下。
1️⃣首先定义不同的残差块,作为定义学生网络和教师网络的基本模块。
如上图所示,左边作为教师网络的基本模块,右边作为学生网络的基本模块。每个模块中,学生网络都比教师网络的特征提取能力差。
类名中的Leak是因为原来使用的激活函数是LeakyReLU(下面教师网络的模块现已改用Hardswish)。😅
class Res_ConvBNLeak_2(nn.Module):
    """Residual block made of two 3x3 Conv-BatchNorm-Hardswish layers.

    Both convolutions use stride 1 and padding 1 and keep ``in_channel``
    channels, so the branch output has the same shape as the input and is
    added to it. The "Leak" in the class name is historical: the activation
    actually used here is Hardswish.

    Args:
        in_channel: number of input channels (the output has the same number).
    """

    def __init__(self, in_channel):
        super(Res_ConvBNLeak_2, self).__init__()
        # Attribute names conv1/conv2 are part of the state_dict keys; keep
        # them stable so previously saved checkpoints still load.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(in_channel),
            nn.Hardswish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(in_channel),
            nn.Hardswish(),
        )

    def forward(self, x):
        """Return ``x + conv2(conv1(x))``."""
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        return identity + out
class Res_ConvBNLeak_1(nn.Module):
    """Residual block made of a single 3x3 Conv-BatchNorm-LeakyReLU layer.

    The convolution uses stride 1 and padding 1 and keeps ``in_channel``
    channels, so the branch output has the same shape as the input and is
    added to it.

    Args:
        in_channel: number of input channels (the output has the same number).
    """

    def __init__(self, in_channel):
        super(Res_ConvBNLeak_1, self).__init__()
        # The attribute name conv1 is part of the state_dict keys; keep it
        # stable so previously saved checkpoints still load.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(in_channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Return ``x + conv1(x)``."""
        identity = x
        out = self.conv1(x)
        return identity + out
2️⃣然后定义不同的阶段,每个阶段都是先下采样,然后进行通道融合和像素融合。
如上图所示,先是下采样层,然后就是残差层,残差层根据不同的阶段重复不同的次数。
def __make_stage(self, stage):
    """Build one stage: a downsampler followed by repeated residual blocks.

    Reads ``self.repeat[stage]`` for the number of residual blocks and
    ``self.inchannel_nums[stage - 1]`` / ``self.inchannel_nums[stage]`` for
    the downsampler's input / output channel counts.

    Args:
        stage: stage index, 1 to 3 inclusive.

    Returns:
        ``nn.Sequential`` built from an ``OrderedDict`` of named sub-modules.

    Raises:
        ValueError: if ``stage`` is outside [1, 3].
    """
    # Negative indices must be rejected too: they would silently wrap around
    # when indexing self.repeat / self.inchannel_nums.
    if not 1 <= stage <= 3:
        raise ValueError("stage value should be in [1, 3], but got {}".format(stage))
    moduel = OrderedDict()
    repeat = self.repeat[stage]
    stage_name = "stage{}".format(stage)
    # The key strings below become state_dict prefixes; they are kept exactly
    # as they were so saved checkpoints remain loadable.
    down_name = stage_name + '_downsample'
    moduel[down_name] = downsampler(self.inchannel_nums[stage - 1],
                                    self.inchannel_nums[stage])
    for i in range(repeat):
        moduel_name = stage_name + '_repeat_moduel_{}'.format(i)
        moduel[moduel_name] = Res_ConvBNLeak_2(in_channel=self.inchannel_nums[stage])
    return nn.Sequential(moduel)
3️⃣最后,网络结构如下。
学生网络(4.49M):
class StudentNet(nn.Module):
    """Student network: a strided stem plus three stages of one residual block each.

    ``forward`` returns the feature map of every stage in addition to the
    class log-probabilities, so a loss can be attached to each stage.
    """

    def __init__(self):
        super(StudentNet, self).__init__()
        # Residual blocks per stage. Index 0 is a placeholder: the stem
        # (conv1) is not built by __make_stage.
        self.repeat = [-1, 1, 1, 1]
        # Output channels of the stem (index 0) and of stages 1..3.
        self.inchannel_nums = [32, 64, 128, 256]
        # Stem, stride 2: 32*32 -> 16*16 for CIFAR-sized input.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
        )
        # Spatial sizes below assume `downsampler` halves the resolution
        # (defined elsewhere -- TODO confirm).
        # 16*16 -> 8*8
        self.stage2 = self.__make_stage(1)
        # 8*8 -> 4*4
        self.stage3 = self.__make_stage(2)
        # 4*4 -> 2*2
        self.stage4 = self.__make_stage(3)
        # Classifier over 256 channels * 2 * 2 positions, 10 classes.
        self.fc = nn.Sequential(
            nn.Linear(256 * 4, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Run the network and return every stage output.

        Returns:
            Tuple ``(out1, out2, out3, out4, OUT)``: the stem output, the three
            stage outputs, and the log-softmax class scores.
        """
        out1 = self.conv1(x)
        out2 = self.stage2(out1)
        out3 = self.stage3(out2)
        out4 = self.stage4(out3)
        # Flatten per sample instead of reshape(-1, 256*4): if the final map is
        # not 2*2, reshape(-1, ...) would silently fold spatial positions into
        # the batch dimension; flatten makes nn.Linear fail loudly instead.
        out = out4.flatten(1)
        OUT = self.fc(out)
        return out1, out2, out3, out4, OUT

    def __make_stage(self, stage):
        """Build stage ``stage`` (1..3): a downsampler plus residual blocks.

        Raises:
            ValueError: if ``stage`` is outside [1, 3].
        """
        # Negative indices must be rejected too: they would silently wrap
        # around when indexing self.repeat / self.inchannel_nums.
        if not 1 <= stage <= 3:
            raise ValueError("stage value should be in [1, 3], but got {}".format(stage))
        moduel = OrderedDict()
        repeat = self.repeat[stage]
        stage_name = "stage{}".format(stage)
        # Key strings become state_dict prefixes; kept unchanged so saved
        # checkpoints remain loadable.
        down_name = stage_name + '_downsample'
        moduel[down_name] = downsampler(self.inchannel_nums[stage - 1],
                                        self.inchannel_nums[stage])
        for i in range(repeat):
            moduel_name = stage_name + '_repeat_moduel_{}'.format(i)
            moduel[moduel_name] = Res_ConvBNLeak_1(in_channel=self.inchannel_nums[stage])
        return nn.Sequential(moduel)
使用工具可视化,如下。
教师网络(26.9M):
class TeacherNet(nn.Module):
    """Teacher network: a strided stem plus three stages of 2, 6 and 4 residual blocks.

    ``forward`` returns the feature map of every stage in addition to the
    class log-probabilities, so each stage output can serve as a target.
    """

    def __init__(self):
        super(TeacherNet, self).__init__()
        # Residual blocks per stage. Index 0 is a placeholder: the stem
        # (conv1) is not built by __make_stage.
        self.repeat = [-1, 2, 6, 4]
        # Output channels of the stem (index 0) and of stages 1..3.
        self.inchannel_nums = [32, 64, 128, 256]
        # Stem, stride 2: 32*32 -> 16*16 for CIFAR-sized input.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
        )
        # Spatial sizes below assume `downsampler` halves the resolution
        # (defined elsewhere -- TODO confirm).
        # 16*16 -> 8*8
        self.stage2 = self.__make_stage(1)
        # 8*8 -> 4*4
        self.stage3 = self.__make_stage(2)
        # 4*4 -> 2*2
        self.stage4 = self.__make_stage(3)
        # Classifier over 256 channels * 2 * 2 positions, 10 classes.
        self.fc = nn.Sequential(
            nn.Linear(256 * 4, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Run the network and return every stage output.

        Returns:
            Tuple ``(out1, out2, out3, out4, OUT)``: the stem output, the three
            stage outputs, and the log-softmax class scores.
        """
        out1 = self.conv1(x)
        out2 = self.stage2(out1)
        out3 = self.stage3(out2)
        out4 = self.stage4(out3)
        # Flatten per sample instead of reshape(-1, 256*4): if the final map is
        # not 2*2, reshape(-1, ...) would silently fold spatial positions into
        # the batch dimension; flatten makes nn.Linear fail loudly instead.
        out = out4.flatten(1)
        OUT = self.fc(out)
        return out1, out2, out3, out4, OUT

    def __make_stage(self, stage):
        """Build stage ``stage`` (1..3): a downsampler plus residual blocks.

        Raises:
            ValueError: if ``stage`` is outside [1, 3].
        """
        # Negative indices must be rejected too: they would silently wrap
        # around when indexing self.repeat / self.inchannel_nums.
        if not 1 <= stage <= 3:
            raise ValueError("stage value should be in [1, 3], but got {}".format(stage))
        moduel = OrderedDict()
        repeat = self.repeat[stage]
        stage_name = "stage{}".format(stage)
        # Key strings become state_dict prefixes; kept unchanged so saved
        # checkpoints remain loadable.
        down_name = stage_name + '_downsample'
        moduel[down_name] = downsampler(self.inchannel_nums[stage - 1],
                                        self.inchannel_nums[stage])
        for i in range(repeat):
            moduel_name = stage_name + '_repeat_moduel_{}'.format(i)
            moduel[moduel_name] = Res_ConvBNLeak_2(in_channel=self.inchannel_nums[stage])
        return nn.Sequential(moduel)
使用工具可视化,如下。
3.训练
使用`CIFAR-10`作为数据集;数据集标签对应的损失使用`NLLLoss`计算,教师网络特征图(作为标签)对应的损失使用`MSELoss`计算。
在训练的时候,在测试集上计算`top-1`和`top-5`精度。
计算代码如下。
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor indexed as (sample, class); the k largest scores
            along dimension 1 are taken as the predictions.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: iterable of k values, e.g. ``(1, 5)`` for top-1 and top-5.

    Returns:
        A list with one single-element float tensor per k, in the order of
        ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Class indices of the maxk highest scores per sample, best first.
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose to (maxk, batch) so that row j holds every sample's
        # (j+1)-th guess.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape, not view: `correct` inherits the transposed layout of
            # `pred`, so for k > 1 the slice is non-contiguous and view(-1)
            # raises a RuntimeError on current PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
损失计算如下:
# Loss computation of the distillation loop (net2 = student, net1 = teacher).
# When True, the hard-label NLL term is included. The feature-map MSE terms
# only involve out1..out4, which do not pass through the student's `fc` head,
# so without the label term the classifier would receive no gradient at all.
use_label = True
for i, (img_, label_) in enumerate(train_dataloader):
    img, label = img_.cuda(), label_.cuda()
    out1, out2, out3, out4, out = net2(img)
    # The teacher only provides targets: skip building its autograd graph so
    # no gradient flows into (or memory is spent on) the teacher.
    with torch.no_grad():
        OUT1, OUT2, OUT3, OUT4, _ = net1(img)
    # Feature-matching terms, weighted by increasing powers of `pi`
    # (defined outside this snippet).
    loss1 = nn.functional.mse_loss(out1, OUT1) * pi
    loss2 = nn.functional.mse_loss(out2, OUT2) * (pi ** 2)
    loss3 = nn.functional.mse_loss(out3, OUT3) * (pi ** 3)
    loss4 = nn.functional.mse_loss(out4, OUT4) * (pi ** 4)
    if use_label:
        # Student is supervised by the dataset labels and the teacher.
        loss0 = nn.functional.nll_loss(out, label)
        loss = loss0 + loss1 + loss2 + loss3 + loss4
    else:
        # Student is supervised by the teacher's feature maps only.
        loss = loss1 + loss2 + loss3 + loss4
4.结果
对于使用的教师网络,测试结果如下:
Net:teachernet
Dataset:CIFAR10-test Len:10000
Test_Batch_Size:1024
Test_Epoch:10
use saved_teachernet\teachernet_@epoch140.pt
Act Hardswish
@test EOPCH:1 top1 acc:74.25130462646484 top5 acc:97.13541412353516
@test EOPCH:2 top1 acc:74.39236450195312 top5 acc:97.15711975097656
@test EOPCH:3 top1 acc:74.21875 top5 acc:97.11371612548828
@test EOPCH:4 top1 acc:74.13194274902344 top5 acc:97.04861450195312
@test EOPCH:5 top1 acc:74.16449737548828 top5 acc:97.08116149902344
@test EOPCH:6 top1 acc:74.33810424804688 top5 acc:97.09201049804688
@test EOPCH:7 top1 acc:74.56597137451172 top5 acc:97.15711975097656
@test EOPCH:8 top1 acc:74.26215362548828 top5 acc:97.16796875
@test EOPCH:9 top1 acc:74.16449737548828 top5 acc:97.21137237548828
@test EOPCH:10 top1 acc:74.68533325195312 top5 acc:97.13541412353516
@test EOPCH:11 top1 acc:74.35980987548828 top5 acc:97.12456512451172
对于直接使用数据集的学生网络,测试结果如下:
Net:studentnet
Dataset:CIFAR10-test Len:10000
Test_Batch_Size:1024
Test_Epoch:10
use student_without_teacher\studentnet_@epoch220.pt
Act Hardswish
PS:compare theachernet:base;student net:1-1-1 without teacher,acc
@test EOPCH:1 top1 acc:70.27994537353516 top5 acc:96.89669799804688
@test EOPCH:2 top1 acc:70.36675262451172 top5 acc:96.90755462646484
@test EOPCH:3 top1 acc:70.32334899902344 top5 acc:96.94010162353516
@test EOPCH:4 top1 acc:70.29080200195312 top5 acc:96.84244537353516
@test EOPCH:5 top1 acc:70.46440887451172 top5 acc:97.00521087646484
@test EOPCH:6 top1 acc:70.14974212646484 top5 acc:96.88584899902344
@test EOPCH:7 top1 acc:70.17144012451172 top5 acc:96.875
@test EOPCH:8 top1 acc:70.35590362548828 top5 acc:96.77734375
@test EOPCH:9 top1 acc:70.12803649902344 top5 acc:96.89669799804688
@test EOPCH:10 top1 acc:70.46440887451172 top5 acc:96.89669799804688
@test EOPCH:11 top1 acc:70.45355987548828 top5 acc:96.875
使用教师网络的输出数据后,学生网络的测试结果如下:
Net:studentnet
Dataset:CIFAR10-test Len:10000
Test_Batch_Size:1024
Test_Epoch:10
use saved_studentnet\dstill_studentnet_@epoch320.pt
Act Hardswish
PS:compare theachernet:base;student net:1-1-1
@test EOPCH:1 top1 acc:74.37065887451172 top5 acc:97.49349212646484
@test EOPCH:2 top1 acc:74.44661712646484 top5 acc:97.45008850097656
@test EOPCH:3 top1 acc:74.4140625 top5 acc:97.53689575195312
@test EOPCH:4 top1 acc:74.35980987548828 top5 acc:97.47178649902344
@test EOPCH:5 top1 acc:74.69618225097656 top5 acc:97.53689575195312
@test EOPCH:6 top1 acc:74.50086975097656 top5 acc:97.50434112548828
@test EOPCH:7 top1 acc:74.21875 top5 acc:97.43923950195312
@test EOPCH:8 top1 acc:74.65277862548828 top5 acc:97.61284637451172
@test EOPCH:9 top1 acc:74.50086975097656 top5 acc:97.49349212646484
@test EOPCH:10 top1 acc:74.45746612548828 top5 acc:97.43923950195312
@test EOPCH:11 top1 acc:74.62022399902344 top5 acc:97.51519012451172
精度 | 教师网络 | 学生网络 | 有教师的学生网络 |
---|---|---|---|
top1 | 74.321339 | 70.31348627 | 74.47620808 |
top5 | 97.12949787 | 96.89019012 | 97.49940976 |
从结果看,使用教师网络蒸馏的学生网络的精度似乎比教师网络本身还要高一些,看来教师网络还没有训练到最好的状态。