一、蒸馏流程
- 加载学生和教师模型
- 定义逻辑蒸馏loss
- 计算逻辑蒸馏loss
- 定义提取特征层函数
- 定义特征蒸馏loss
- 计算特征蒸馏loss
- 计算学生loss和总loss
- 正常训练
二、代码
(1) 加载学生和教师模型
# Student model: built fresh from the config; trained from scratch / normally.
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
# Teacher model: rebuilt from cfg (or the checkpoint's own yaml), then
# initialized from a pretrained checkpoint so it can supervise the student.
# NOTE(review): this logs "Loaded" before the checkpoint is actually loaded below.
LOGGER.info(f'Loaded teacher model {t_cfg}') # report
t_ckpt = torch.load(t_weights, map_location=device) # load checkpoint
t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
# Skip anchor tensors when a cfg/hyp override supplies anchors (unless resuming).
exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = t_ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
# Keep only keys whose names and shapes match the freshly built teacher.
csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude) # intersect
t_model.load_state_dict(csd, strict=False) # load
(2) 定义逻辑蒸馏loss
def compute_distillation_output_loss(p, t_p, model, d_weight=1):
    """Logit (output) distillation loss between student and teacher heads.

    Per detection layer, computes per-element MSE between student and teacher
    predictions for box (first 4 channels), objectness (channel 4) and class
    scores (channels 5:), each gated by the teacher's sigmoid objectness so
    that confident teacher cells dominate the loss.

    Args:
        p: list of student prediction tensors, one per detection layer, with
           layout (..., 5 + nc) = [x, y, w, h, obj, cls...] in the last dim.
        t_p: matching list of teacher prediction tensors (same shapes as p).
        model: student model; provides .hyp (box/obj/cls loss gains) and .nc.
        d_weight: scalar weight applied to the summed distillation loss.

    Returns:
        1-element tensor: d_weight * (obj + box + cls distillation losses).
    """
    h = model.hyp  # hyperparameter gains
    device = t_p[0].device  # replaces legacy torch.cuda.FloatTensor dispatch
    t_lcls = torch.zeros(1, device=device)
    t_lbox = torch.zeros(1, device=device)
    t_lobj = torch.zeros(1, device=device)
    # One per-element MSE suffices for all three terms (the original built
    # three identical instances); weighting happens before the mean.
    mse = nn.MSELoss(reduction="none")
    for pi, t_pi in zip(p, t_p):  # per detection layer
        # Teacher objectness gates each cell's contribution.
        t_obj_scale = t_pi[..., 4].sigmoid()
        # Broadcasting via unsqueeze(-1) is equivalent to the original
        # .repeat(1, 1, 1, 1, C) without materializing the repeated tensor.
        t_lbox += torch.mean(mse(pi[..., :4], t_pi[..., :4]) * t_obj_scale.unsqueeze(-1))
        if model.nc > 1:  # cls loss (only if multiple classes)
            t_lcls += torch.mean(mse(pi[..., 5:], t_pi[..., 5:]) * t_obj_scale.unsqueeze(-1))
        t_lobj += torch.mean(mse(pi[..., 4], t_pi[..., 4]) * t_obj_scale)
    t_lbox *= h['box']
    t_lobj *= h['obj']
    t_lcls *= h['cls']
    return (t_lobj + t_lbox + t_lcls) * d_weight
(3) 计算逻辑蒸馏loss
# Forward: student under AMP autocast; teacher under no_grad so no gradients
# (or activation storage for backward) are kept for the teacher.
with amp.autocast(enabled=cuda):
pred = model(imgs) # forward
with torch.no_grad():
t_pred = t_model(imgs)
# Logit distillation term; d_weight=10 scales it against the student loss.
d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)
(4) 定义提取特征层函数
if opt.d_feature:
# Feature maps captured by forward hooks are stashed here, keyed by the
# names registered in get_hooks() ("s_f1".."s_f3", "t_f1".."t_f3").
activation = {}
def get_activation(name):
# Closure factory: returns a forward hook that records the layer output.
def hook(model, inputs, outputs):
activation[name] = outputs
return hook
def get_hooks():
# Register forward hooks on three intermediate layers of both models;
# returns the handles so callers can .remove() them after the forward.
hooks = []
# # S-model
# hooks.append(model.model._modules["3"].register_forward_hook(get_activation("s_f1")))
# hooks.append(model.model._modules["5"].register_forward_hook(get_activation("s_f2")))
# hooks.append(model.model._modules["9"].register_forward_hook(get_activation("s_f3")))
# # T-model
# hooks.append(t_model.model._modules["3"].register_forward_hook(get_activation("t_f1")))
# hooks.append(t_model.model._modules["5"].register_forward_hook(get_activation("t_f2")))
# hooks.append(t_model.model._modules["9"].register_forward_hook(get_activation("t_f3")))
# S-model
# for k, v in t_model._modules.items():
# print(f"tmodel._modules_k: {k}; v: {v}")
# NOTE(review): _modules['module'] suggests the student is wrapped in
# DataParallel/DDP while the teacher is not — confirm against the trainer.
hooks.append(model._modules['module'].model[4].register_forward_hook(get_activation("s_f1")))
hooks.append(model._modules['module'].model[6].register_forward_hook(get_activation("s_f2")))
hooks.append(model._modules['module'].model[9].register_forward_hook(get_activation("s_f3")))
# T-model
hooks.append(t_model._modules['model'][4].register_forward_hook(get_activation("t_f1")))
hooks.append(t_model._modules['model'][6].register_forward_hook(get_activation("t_f2")))
hooks.append(t_model._modules['model'][9].register_forward_hook(get_activation("t_f3")))
return hooks
# feature convert
# Converters project student feature channels (128/256/512) up to the
# teacher's channel widths (c1/c2/c3) so MSE can compare them elementwise.
from models.common import Converter
c1 = 192
c2 = 384
c3 = 768
S_Converter_1 = Converter(128, c1, act=True)
S_Converter_2 = Converter(256, c2, act=True)
S_Converter_3 = Converter(512, c3, act=True)
S_Converter_1.to(device)
S_Converter_2.to(device)
S_Converter_3.to(device)
S_Converter_1.train()
S_Converter_2.train()
S_Converter_3.train()
# Teacher side only passes features through ReLU6 (no channel change);
# the commented Converter variants below are a previous alternative.
T_Converter_1 = nn.ReLU6()
T_Converter_2 = nn.ReLU6()
T_Converter_3 = nn.ReLU6()
# T_Converter_1 = Converter(c1, 32, act=True)
# T_Converter_2 = Converter(c2, 96, act=True)
# T_Converter_3 = Converter(c3, 320, act=True)
T_Converter_1.to(device)
T_Converter_2.to(device)
T_Converter_3.to(device)
T_Converter_1.train()
T_Converter_2.train()
T_Converter_3.train()
(5) 定义特征蒸馏loss
def compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1):
    """Feature-map distillation loss.

    Sums the mean-squared error between each student/teacher feature pair
    and scales the total by f_weight. Generalized from the original
    hard-coded three pairs to any number of pairs (backward-compatible).

    Args:
        s_f: list of student feature tensors (already channel-converted to
             match the teacher's shapes).
        t_f: list of teacher feature tensors, same length and shapes as s_f.
        model: unused; kept for interface compatibility with callers.
        f_weight: scalar weight for the summed feature loss.

    Returns:
        1-element tensor: f_weight * sum of per-pair MSE losses.
    """
    mse = nn.MSELoss(reduction="mean")  # one instance serves every pair
    loss = torch.zeros(1, device=s_f[0].device)
    for sf, tf in zip(s_f, t_f):
        loss += mse(sf, tf)
    return loss * f_weight
(6) 计算特征蒸馏loss
# Hooks must be registered before the forward pass so the hooks populate
# `activation` with this batch's intermediate feature maps.
if opt.d_feature:
hooks = get_hooks()
if opt.d_feature:
# Student features are channel-converted to teacher widths; teacher
# features only pass through ReLU6 (see converter setup).
s_f1 = S_Converter_1(activation["s_f1"])
s_f2 = S_Converter_2(activation["s_f2"])
s_f3 = S_Converter_3(activation["s_f3"])
s_f = [s_f1, s_f2, s_f3]
t_f1 = T_Converter_1(activation["t_f1"])
t_f2 = T_Converter_2(activation["t_f2"])
t_f3 = T_Converter_3(activation["t_f3"])
t_f = [t_f1, t_f2, t_f3]
if opt.d_feature:
# NOTE(review): this value is recomputed again when forming the total
# loss (section 7) — one of the two computations is redundant.
d_feature_loss = compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1)
# Remove hooks after use so handles don't accumulate across batches.
if opt.d_feature:
for hook in hooks:
hook.remove()
(7) 计算学生loss和总loss
# Student loss (standard detection loss against ground-truth targets)
s_loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if opt.d_feature:
# Feature distillation + logit distillation
d_feature_loss = compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1)
loss = d_outputs_loss + s_loss + d_feature_loss
else:
# Logit distillation only
loss = d_outputs_loss + s_loss
(8) 正常训练
# Backward: AMP grad-scaler scales the combined loss before backprop.
scaler.scale(loss).backward()