一、Introduction to MGD Feature Distillation
Masked Generative Distillation (MGD) masks random pixels of the student's feature map and forces the student to generate the teacher's complete feature through a simple block. It is a truly general feature-based distillation method that can be used for a variety of tasks, including image classification, object detection, semantic segmentation, and instance segmentation.
Reference paper: Masked Generative Distillation (Zhendong Yang et al., ECCV 2022)
二、MGD Feature Distillation Implementation Flow
(1) Align the student model's feature channels with the teacher model's
# ---- align feature channels ----#
if student_channels != teacher_channels:
    self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
else:
    self.align = None
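A minimal sketch of the alignment step in isolation (the channel counts below are hypothetical; in practice they come from the student and teacher backbones):

```python
import torch
import torch.nn as nn

student_channels, teacher_channels = 256, 512  # hypothetical values

# A 1x1 conv maps the student feature into the teacher's channel space;
# it is skipped entirely when the counts already match.
if student_channels != teacher_channels:
    align = nn.Conv2d(student_channels, teacher_channels,
                      kernel_size=1, stride=1, padding=0)
else:
    align = None

s_feat = torch.randn(2, student_channels, 20, 20)  # (N, C, H, W)
aligned = align(s_feat) if align is not None else s_feat
# the 1x1 conv changes only the channel dimension:
# (2, 256, 20, 20) -> (2, 512, 20, 20); H and W are untouched
```

Because the kernel is 1x1 with stride 1 and no padding, the spatial resolution is preserved, so only the channel dimension has to be reconciled before the loss can be computed.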
(2) Randomly mask the student features
# Masked student feature
mat = torch.rand((N, C, 1, 1)).to(device)
mat = torch.where(mat < self.lambda_mgd, 0, 1).to(device)
masked_fea = torch.mul(s_pred, mat)
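A small standalone sketch of the masking step. Note that because the random matrix in this code has shape (N, C, 1, 1), it broadcasts over H and W: whole channels of the student feature are zeroed out, rather than individual spatial pixels.

```python
import torch

torch.manual_seed(0)
lambda_mgd = 0.5            # default masking ratio
N, C, H, W = 4, 64, 16, 16  # hypothetical feature shape
s_pred = torch.randn(N, C, H, W)

# one random value per (sample, channel), broadcast over H and W
mat = torch.rand((N, C, 1, 1))
mat = torch.where(mat < lambda_mgd, 0.0, 1.0)
masked_fea = s_pred * mat

kept_ratio = mat.mean().item()  # roughly 1 - lambda_mgd of channels survive
```

The masked channels carry no information, so the generation block that follows must infer them from the surviving channels — this is what forces the student to learn a richer representation.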
(3) Define the generation block and let the student's block restore the teacher's complete features
# ---- generation block ----#
self.generation = nn.Sequential(
    nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))
# ---- let the student's block restore the teacher's complete features ----#
new_fea = self.generation(masked_fea)
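A quick shape check of the generation block in isolation (the channel count is hypothetical): 3x3 convs with padding=1 preserve H and W, so the generated feature stays directly comparable to the teacher feature.

```python
import torch
import torch.nn as nn

teacher_channels = 128  # hypothetical value

# two 3x3 convs with padding=1: channel count and spatial size unchanged
generation = nn.Sequential(
    nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))

masked_fea = torch.randn(2, teacher_channels, 10, 10)
new_fea = generation(masked_fea)
# new_fea has exactly the same (N, C, H, W) shape as masked_fea
```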
(4) Compute the loss
# loss
L2 = nn.MSELoss(reduction='sum')
loss = L2(new_fea, t_pred) / N  # N: batch size
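A worked example of the normalization: with reduction='sum', MSELoss adds up the squared error over all N*C*H*W elements, and dividing by N gives a per-sample value that still scales with the feature size C*H*W.

```python
import torch
import torch.nn as nn

N, C, H, W = 4, 8, 5, 5          # hypothetical shape
new_fea = torch.zeros(N, C, H, W)
t_pred = torch.ones(N, C, H, W)

L2 = nn.MSELoss(reduction='sum')
# every element differs by exactly 1, so the sum is N*C*H*W = 800
loss = L2(new_fea, t_pred) / N   # -> C*H*W = 200.0
```

This is why the default alpha_mgd weight in the code below is tiny (0.00007): the unweighted loss is on the scale of the feature volume.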
三、Complete MGD Feature Distillation Code Implementation
import torch
import torch.nn as nn

# ---- MGD is a truly general feature-based distillation method, usable for a variety of tasks, including image classification, object detection, semantic segmentation, and instance segmentation ----#
class MGDLoss(nn.Module):
    def __init__(self, student_channels, teacher_channels, name, alpha_mgd=0.00007, lambda_mgd=0.5):
        super(MGDLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        # ---- align feature channels ----#
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None
        # ---- generation block ----#
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self, s_pred, t_pred):
        # ---- spatial sizes must already match; only channels are aligned here ----#
        assert s_pred.shape[-2:] == t_pred.shape[-2:]
        if self.align is not None:
            preds_S = self.align(s_pred)
        else:
            preds_S = s_pred  # channels already match, no alignment needed
        # ---- compute loss ----#
        loss = self.get_dis_loss(preds_S, t_pred) * self.alpha_mgd
        return loss

    def get_dis_loss(self, s_pred, t_pred):
        L2 = nn.MSELoss(reduction='sum')
        N, C, H, W = t_pred.shape
        device = s_pred.device
        # mask the student feature channel-wise
        mat = torch.rand((N, C, 1, 1)).to(device)
        mat = torch.where(mat < self.lambda_mgd, 0, 1).to(device)
        masked_fea = torch.mul(s_pred, mat)
        # let the student's block restore the teacher's complete features
        new_fea = self.generation(masked_fea)
        # loss, normalized by batch size
        loss = L2(new_fea, t_pred) / N
        return loss
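The class above can be exercised end to end with a minimal sketch. The shapes and the name argument below are hypothetical; note the explicit fallback to s_pred when no channel alignment is needed, without which forward would fail when student and teacher channels already match.

```python
import torch
import torch.nn as nn

class MGDLoss(nn.Module):
    def __init__(self, student_channels, teacher_channels, name,
                 alpha_mgd=0.00007, lambda_mgd=0.5):
        super(MGDLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels,
                                   kernel_size=1, stride=1, padding=0)
        else:
            self.align = None
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self, s_pred, t_pred):
        assert s_pred.shape[-2:] == t_pred.shape[-2:]
        # fall back to s_pred when channels already match
        preds_S = self.align(s_pred) if self.align is not None else s_pred
        return self.get_dis_loss(preds_S, t_pred) * self.alpha_mgd

    def get_dis_loss(self, s_pred, t_pred):
        L2 = nn.MSELoss(reduction='sum')
        N, C, H, W = t_pred.shape
        mat = torch.rand((N, C, 1, 1), device=s_pred.device)
        mat = torch.where(mat < self.lambda_mgd, 0.0, 1.0)
        masked_fea = s_pred * mat
        new_fea = self.generation(masked_fea)
        return L2(new_fea, t_pred) / N

# hypothetical setup: 256-channel student features vs. 512-channel
# teacher features at the same 20x20 resolution
torch.manual_seed(0)
mgd = MGDLoss(student_channels=256, teacher_channels=512, name='layer4')
s_pred = torch.randn(2, 256, 20, 20, requires_grad=True)
t_pred = torch.randn(2, 512, 20, 20)

loss = mgd(s_pred, t_pred)
loss.backward()  # gradients reach s_pred, align, and generation
```

In training, this loss would be added to the student's task loss; the teacher feature t_pred is typically detached (or produced under torch.no_grad()) so that only the student side is updated.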