[Knowledge Distillation] Hands-On Logit and Feature Distillation for YOLOv5

1. Distillation Pipeline

  • Load the student and teacher models
  • Define the logit distillation loss
  • Compute the logit distillation loss
  • Define the feature-extraction hook functions
  • Define the feature distillation loss
  • Compute the feature distillation loss
  • Compute the student loss and the total loss
  • Train as usual (the full iteration is sketched below)
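
Before diving into the code, here is a minimal sketch of how these steps combine into one training iteration. It assumes the functions defined in section 2 below (`compute_loss` is the stock YOLOv5 detection loss) and omits AMP, hook management, and logging for brevity; it is an outline, not the author's full train.py.

# Minimal sketch of one distillation training step (outline only; assumes the
# functions defined in section 2 below and the stock YOLOv5 compute_loss).
for imgs, targets in dataloader:
    imgs, targets = imgs.to(device), targets.to(device)

    pred = model(imgs)                       # student forward
    with torch.no_grad():
        t_pred = t_model(imgs)               # teacher forward, no gradients

    s_loss, _ = compute_loss(pred, targets)  # normal detection loss
    loss = s_loss + compute_distillation_output_loss(pred, t_pred, model)
    if opt.d_feature:                        # optional feature distillation
        loss += compute_distillation_feature_loss(s_f, t_f, model)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()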

2. Code

(1) Load the student and teacher models

# Student model
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

# Teacher model
LOGGER.info(f'Loaded teacher model {t_cfg}')  # report
t_ckpt = torch.load(t_weights, map_location=device)  # load checkpoint
t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else []  # exclude keys
csd = t_ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude)  # intersect
t_model.load_state_dict(csd, strict=False)  # load
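
A practical addition not in the original snippet: freezing the teacher keeps its BatchNorm statistics fixed and avoids tracking gradients through it.

# Assumed addition: put the teacher in eval mode and freeze its weights
t_model.eval()
for p in t_model.parameters():
    p.requires_grad = False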

(2) Define the logit distillation loss

import torch
import torch.nn as nn

def compute_distillation_output_loss(p, t_p, model, d_weight=1):
    """Logit (output) distillation loss between student predictions p and teacher predictions t_p."""
    device = t_p[0].device
    t_lcls = torch.zeros(1, device=device)
    t_lbox = torch.zeros(1, device=device)
    t_lobj = torch.zeros(1, device=device)
    h = model.hyp  # hyperparameters
    # Element-wise MSE; the mean is taken after weighting by teacher objectness.
    DboxLoss = nn.MSELoss(reduction="none")
    DclsLoss = nn.MSELoss(reduction="none")
    DobjLoss = nn.MSELoss(reduction="none")
    for i, pi in enumerate(p):  # layer index, layer predictions
        t_pi = t_p[i]  # teacher predictions for this detection layer
        t_obj_scale = t_pi[..., 4].sigmoid()  # weight each cell by the teacher's objectness
        # box loss, weighted by teacher objectness
        b_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4)
        t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale)
        if model.nc > 1:  # cls loss (only if multiple classes)
            c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, model.nc)
            t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale)
        t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale)
    t_lbox *= h['box']
    t_lobj *= h['obj']
    t_lcls *= h['cls']
    loss = (t_lobj + t_lbox + t_lcls) * d_weight
    return loss
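
As a quick sanity check, the loss can be exercised with random tensors shaped like YOLOv5 raw training outputs, (batch, anchors, grid_h, grid_w, 5 + nc) per detection layer. The model stub below is a hypothetical stand-in exposing only the attributes the loss reads; it is not part of YOLOv5.

# Hypothetical smoke test with random tensors shaped like YOLOv5 raw outputs.
class _Stub:  # minimal stand-in exposing the attributes the loss reads
    nc = 80
    hyp = {'box': 0.05, 'obj': 1.0, 'cls': 0.5}

p = [torch.randn(2, 3, s, s, 85) for s in (80, 40, 20)]    # student outputs
t_p = [torch.randn(2, 3, s, s, 85) for s in (80, 40, 20)]  # teacher outputs
print(compute_distillation_output_loss(p, t_p, _Stub(), d_weight=10))  # scalar tensor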

(3) Compute the logit distillation loss

# Forward
with amp.autocast(enabled=cuda):  # mixed-precision forward (torch.cuda.amp)
    pred = model(imgs)  # student forward

    with torch.no_grad():
        t_pred = t_model(imgs)  # teacher forward, no gradient tracking

    d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)

(4) Define the feature-extraction hook functions

if opt.d_feature:
    activation = {}

    def get_activation(name):
        """Return a forward hook that stores the module's output under `name`."""
        def hook(model, inputs, outputs):
            activation[name] = outputs

        return hook

    def get_hooks():
        """Register forward hooks on matching student/teacher backbone layers."""
        hooks = []
        # Student: the model is wrapped (e.g. by DataParallel), hence the extra
        # ['module'] level. Layers 4/6/9 are the backbone stages being matched;
        # adjust the indices to your architecture.
        hooks.append(model._modules['module'].model[4].register_forward_hook(get_activation("s_f1")))
        hooks.append(model._modules['module'].model[6].register_forward_hook(get_activation("s_f2")))
        hooks.append(model._modules['module'].model[9].register_forward_hook(get_activation("s_f3")))
        # Teacher: unwrapped, so it is indexed directly.
        hooks.append(t_model._modules['model'][4].register_forward_hook(get_activation("t_f1")))
        hooks.append(t_model._modules['model'][6].register_forward_hook(get_activation("t_f2")))
        hooks.append(t_model._modules['model'][9].register_forward_hook(get_activation("t_f3")))
        return hooks

    # Feature converters: 1x1 adapters that project the student feature maps to
    # the teacher's channel widths so the MSE loss is well-defined. (In the
    # original post this block sat after `return hooks` inside get_hooks(), so
    # it was unreachable; it belongs here, at this indentation level.)
    from models.common import Converter

    c1, c2, c3 = 192, 384, 768  # teacher channel widths at the hooked layers
    S_Converter_1 = Converter(128, c1, act=True)
    S_Converter_2 = Converter(256, c2, act=True)
    S_Converter_3 = Converter(512, c3, act=True)
    for conv in (S_Converter_1, S_Converter_2, S_Converter_3):
        conv.to(device)
        conv.train()

    # Teacher features only pass through an activation; no projection needed.
    T_Converter_1 = nn.ReLU6()
    T_Converter_2 = nn.ReLU6()
    T_Converter_3 = nn.ReLU6()
    for act_ in (T_Converter_1, T_Converter_2, T_Converter_3):
        act_.to(device)
        act_.train()
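
`Converter` is not a stock YOLOv5 module; the author adds it to models/common.py. A minimal sketch of what such an adapter typically looks like, assuming a 1x1 Conv-BN-activation block (the author's actual definition may differ):

# Hypothetical sketch of the Converter adapter (1x1 conv + BN + activation),
# assumed to live in models/common.py; the real definition may differ.
import torch.nn as nn

class Converter(nn.Module):
    def __init__(self, c1, c2, act=True):  # c1: in channels, c2: out channels
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, kernel_size=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))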

(5) Define the feature distillation loss

def compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1):
    """
    Feature-map distillation loss.
    Args:
        s_f: list of (converted) student feature maps
        t_f: list of teacher feature maps, same shapes as s_f
        model: the student model (kept for API symmetry; unused here)
        f_weight: weight applied to the summed feature loss

    Returns: weighted feature distillation loss
    """
    loss_func = nn.MSELoss(reduction="mean")  # one MSE instance suffices for all pairs

    dl_1 = loss_func(s_f[0], t_f[0])
    dl_2 = loss_func(s_f[1], t_f[1])
    dl_3 = loss_func(s_f[2], t_f[2])

    return (dl_1 + dl_2 + dl_3) * f_weight
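
A quick shape-level check with random tensors; the sizes are hypothetical, chosen to match the converter output widths above:

# Hypothetical smoke test: three pairs of feature maps with matching shapes.
s_f = [torch.randn(2, c, s, s) for c, s in ((192, 80), (384, 40), (768, 20))]
t_f = [torch.randn(2, c, s, s) for c, s in ((192, 80), (384, 40), (768, 20))]
print(compute_distillation_feature_loss(s_f, t_f, model=None, f_weight=0.1))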

(6) Compute the feature distillation loss

# Before the forward pass: register the hooks so activations are captured.
if opt.d_feature:
    hooks = get_hooks()

# After both forward passes: read the captured activations, align channels,
# and compute the feature loss.
if opt.d_feature:
    s_f1 = S_Converter_1(activation["s_f1"])
    s_f2 = S_Converter_2(activation["s_f2"])
    s_f3 = S_Converter_3(activation["s_f3"])
    s_f = [s_f1, s_f2, s_f3]

    t_f1 = T_Converter_1(activation["t_f1"])
    t_f2 = T_Converter_2(activation["t_f2"])
    t_f3 = T_Converter_3(activation["t_f3"])
    t_f = [t_f1, t_f2, t_f3]

    d_feature_loss = compute_distillation_feature_loss(s_f, t_f, model, f_weight=0.1)

# Once the loss is computed, remove the hooks so they do not accumulate
# across iterations.
if opt.d_feature:
    for hook in hooks:
        hook.remove()

(7) Compute the student loss and the total loss

# Student loss (the normal YOLOv5 detection loss)
s_loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size

if opt.d_feature:
    # feature distillation + logit distillation
    # (d_feature_loss was already computed in step (6); no need to recompute it)
    loss = d_outputs_loss + s_loss + d_feature_loss
else:
    # logit distillation only
    loss = d_outputs_loss + s_loss

(8) Train as usual

# Backward
scaler.scale(loss).backward()
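
The remainder of the step is the standard torch.cuda.amp recipe, unchanged from stock YOLOv5 (which additionally gates it on its gradient-accumulation counter); sketched here for completeness:

# Standard AMP optimizer step
scaler.step(optimizer)  # unscales gradients, then calls optimizer.step()
scaler.update()         # adjust the loss scale for the next iteration
optimizer.zero_grad()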