Dynamic Slimmable Network (CVPR21 Oral): A Comprehensive Analysis

The paper proposes a dynamic pruning strategy: a Dynamic Slimmable Supernet tackles the problem that traditional pruning methods deliver very little actual speedup on hardware, and a Dynamic Slimming Gate is introduced to slim the network on a per-input basis.
Below, the paper's method and its code are analyzed together.

Dynamic Slimmable Supernet

[Figure: DS-Net overview, with the supernet (yellow box) and the dynamic gates (blue box)]
The Dynamic Slimmable Network (DS-Net) achieves dynamic routing for different samples by learning a slimmable supernet together with a dynamic gating mechanism. As shown in the figure above, the supernet in DS-Net (yellow box) is the whole module that carries out the main task, whereas the dynamic gates (blue box) are a series of prediction modules that route each input sample to sub-networks of different widths inside the supernet.

1. Dynamic supernet and slice-able convolution
To avoid producing sparse channels, the authors propose a dynamically slice-able convolution: given the predicted slimming ratio ρ, it dynamically slices and uses only the first ρ × n filters of the convolution (n is the total number of filters). Stacking slice-able convolutions and disabling the dynamic gates yields a dynamic supernet similar to a slimmable network.
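To make this concrete, the toy snippet below (not from the repository; the ratio value is just an example) shows what "slicing the first ρ × n filters" means in practice: the sliced weight is a dense, contiguous tensor rather than a sparsely masked one, which is why it translates into real hardware speedup.

    import torch
    import torch.nn.functional as F

    conv = torch.nn.Conv2d(16, 32, kernel_size=3, padding=1)
    rho = 0.5                                  # predicted slimming ratio (example value)
    k = int(rho * conv.out_channels)           # keep only the first rho * n filters
    x = torch.randn(1, 16, 8, 8)
    # dense, contiguous slice of weight and bias; no sparse channels are produced
    y = F.conv2d(x, conv.weight[:k], conv.bias[:k], padding=1)
    print(y.shape)                             # torch.Size([1, 16, 8, 8])
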
The forward (slice) method of the repository's dynamic conv2d is shown below.

    def forward(self, x):
        if self.prev_channel_choice is None:
            self.prev_channel_choice = self.channel_choice
        # dynamic mode: channel_choice is a (soft one-hot, argmax index) tuple produced by the gate
        if self.mode == 'dynamic' and isinstance(self.channel_choice, tuple):
            weight = self.weight
            if not self.in_chn_static:
                if isinstance(self.prev_channel_choice, int):
                    self.running_inc = self.in_channels_list[self.prev_channel_choice]
                    weight = self.weight[:, :self.running_inc]
                else:
                    self.running_inc = torch.matmul(self.prev_channel_choice[0], self.in_channels_list_tensor)
            if not self.out_chn_static:
                self.running_outc = torch.matmul(self.channel_choice[0], self.out_channels_list_tensor)

            output = F.conv2d(x,
                              weight,
                              self.bias,
                              self.stride,
                              self.padding,
                              self.dilation,
                              self.groups)
            if not self.out_chn_static:
                output = apply_differentiable_gate_channel(output,
                                                           self.channel_choice[0],
                                                           self.out_channels_list)
            self.prev_channel_choice = None
            self.channel_choice = -1
            return output
        else:
            # static slimming mode: channel_choice is an integer index into the channel list,
            # so weight and bias are sliced into dense, contiguous sub-tensors
            if not self.in_chn_static:
                self.running_inc = x.size(1)
            if not self.out_chn_static:
                self.running_outc = self.out_channels_list[self.channel_choice]
            weight = self.weight[:self.running_outc, :self.running_inc]
            bias = self.bias[:self.running_outc] if self.bias is not None else None
            self.running_groups = 1 if self.groups == 1 else self.running_outc
            self.prev_channel_choice = None
            self.channel_choice = -1
            return F.conv2d(x,
                            weight,
                            bias,
                            self.stride,
                            self.padding,
                            self.dilation,
                            self.running_groups)
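
The helper apply_differentiable_gate_channel called in the dynamic branch above is not shown in this excerpt. A minimal sketch of what it could look like, under the assumption that it simply converts the soft one-hot width choice into a per-channel keep mask and multiplies it onto the output feature map (so gradients can flow back to the gate):

    import torch

    def apply_differentiable_gate_channel(x, channel_gate, channel_list):
        # x: (N, C, H, W) feature map; channel_gate: (N, K) soft one-hot over K width choices;
        # channel_list: candidate output-channel counts, e.g. [32, 48, 56, 64]
        n, c = x.size(0), x.size(1)
        # mask[i, j] = 1 if channel j is kept when width choice i is selected
        mask = torch.zeros(len(channel_list), c, device=x.device, dtype=x.dtype)
        for i, keep in enumerate(channel_list):
            mask[i, :keep] = 1.0
        per_sample_mask = channel_gate @ mask  # (N, C), differentiable w.r.t. the gate
        return x * per_sample_mask.view(n, c, 1, 1)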

2. In-place Ensemble Bootstrapping (IEB)
Previous slimmable networks are trained with in-place distillation: the widest sub-network learns to predict the ground-truth labels while generating soft labels that are used to train the narrower sub-networks via knowledge distillation. However, in-place distillation is unstable: the weights can change drastically early in training, which can make training fail altogether or hurt final performance.
[Figure: the In-place Ensemble Bootstrapping (IEB) training scheme]

To address this, the authors propose the In-place Ensemble Bootstrapping (IEB) strategy to stabilize training of the dynamic supernet and ultimately improve performance. First, an exponential moving average (EMA) of the supernet is used to generate the soft labels for training the sub-networks, since the EMA network provides more stable and more accurate targets. Second, a probability ensemble of several models, including the widest sub-network and randomly sampled sub-networks of intermediate widths, serves as the target for training the narrowest sub-network, because a multi-model ensemble provides more diverse, more general and more accurate soft labels (see the figure above).
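
The model_ema teacher used in the training loop below is not defined in this excerpt. A simplified stand-in for it (the decay value and class name are assumptions; the repository presumably uses a timm-style ModelEma that exposes the teacher as model_ema.ema) might look like this:

    import copy
    import torch

    class SimpleEma:
        def __init__(self, model, decay=0.999):       # decay value is an assumed example
            self.ema = copy.deepcopy(model).eval()    # frozen teacher copy of the supernet
            self.decay = decay

        @torch.no_grad()
        def update(self, model):
            # theta_ema <- decay * theta_ema + (1 - decay) * theta
            # (buffers such as BN running statistics are ignored in this sketch)
            for ema_p, p in zip(self.ema.parameters(), model.parameters()):
                ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)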
The training-loop code below implements the IEB strategy.

    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()

        sample_list = ['largest', 'uniform', 'uniform', 'smallest']
        guide_list = []

        for sample_idx, model_mode in enumerate(sample_list):
            seed = seed + 1
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if hasattr(model, 'module'):
                model.module.set_mode(model_mode)
            else:
                model.set_mode(model_mode)
            #model.blocks._modules['3'].first_block.conv_dw.running_inc running_outc
            output = model(input)
 
            if model_mode == 'largest':
                loss = loss_fn(output, target)  # loss computed against the ground-truth labels
                if args.ieb:
                    with torch.no_grad():
                        if hasattr(model_ema.ema, 'module'):
                            model_ema.ema.module.set_mode(model_mode)
                        else:
                            model_ema.ema.set_mode(model_mode)
                        output_largest = model_ema.ema(input)
                    guide_list.append(output_largest)
                loss_largest = loss
            elif model_mode != 'smallest':
                if args.ieb:
                    loss = distill_loss_fn(output, F.softmax(output_largest, dim=1))
                    # with torch.no_grad():
                    #     guide_output = model_ema.ema(input)
                    with torch.no_grad():
                        if hasattr(model_ema.ema, 'module'):
                            model_ema.ema.module.set_mode(model_mode)
                        else:
                            model_ema.ema.set_mode(model_mode)
                        guide_output = model_ema.ema(input)
                    guide_list.append((guide_output))
                    # guide_list.append((output.detach()))
                else:
                    loss = loss_fn(output, target)
            else:  # 'smallest'
                soft_labels_ = [torch.unsqueeze(guide_list[idx], dim=2) for
                                idx in range(len(guide_list))]
                soft_labels_softmax = [F.softmax(i, dim=1) for i in soft_labels_]
                soft_labels_softmax = torch.cat(soft_labels_softmax, dim=2).mean(dim=2)
                if args.ieb:
                    loss = distill_loss_fn(output, soft_labels_softmax)
                else:
                    loss = loss_fn(output, target)
                loss_smallest = loss

            loss = loss / optimizer_step

            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

Dynamic Slimming Gate

The authors design a double-headed dynamic slimming gate and propose the sandwich gate sparsification (SGS) strategy to train it.
[Figure: structure of the double-headed dynamic slimming gate]
1. Double-headed gate design
The gate predicts a one-hot vector from the input feature of the current layer, selecting one value from the list of candidate slimming ratios. The input feature is first reduced by global average pooling to remove the spatial dimensions, then passed through two fully connected layers separated by a ReLU, and an argmax yields the one-hot vector. Since this structure resembles channel attention, all layers except the final fully connected one are shared between the two, forming a double-headed gate with a slimming head and a channel-attention head.
The gate code is as follows:

class MultiHeadGate(nn.Module):
    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU,
                 attn_act_fn=sigmoid, divisor=1, channel_gate_num=None, gate_num_features=1024):
        super(MultiHeadGate, self).__init__()
        self.attn_act_fn = attn_act_fn
        self.channel_gate_num = channel_gate_num
        reduced_chs = make_divisible((reduced_base_chs or in_chs[-1]) * se_ratio, divisor)
        self.avg_pool = DSAdaptiveAvgPool2d(1, channel_list=in_chs)
        self.conv_reduce = DSpwConv2d(in_chs, [reduced_chs], bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = DSpwConv2d([reduced_chs], in_chs, bias=True)

        self.has_gate = False
        if channel_gate_num > 1:
            self.has_gate = True
            # self.gate = nn.Sequential(DSpwConv2d([reduced_chs], [gate_num_features], bias=True),
            #                           act_layer(inplace=True),
            #                           nn.Dropout2d(p=0.2),
            #                           DSpwConv2d([gate_num_features], [channel_gate_num], bias=True))
            self.gate = nn.Sequential(DSpwConv2d([reduced_chs], [channel_gate_num], bias=False))

        self.mode = 'largest'
        self.keep_gate, self.print_gate, self.print_idx = None, None, None
        self.channel_choice = None
        self.initialized = False
        if self.attn_act_fn == 'tanh':
            nn.init.zeros_(self.conv_expand.weight)
            nn.init.zeros_(self.conv_expand.bias)

    def forward(self, x):
        x_pool = self.avg_pool(x)
        x_reduced = self.conv_reduce(x_pool)
        x_reduced = self.act1(x_reduced)
        attn = self.conv_expand(x_reduced)
        if self.attn_act_fn == 'tanh':
            attn = (1 + attn.tanh())
        else:
            attn = self.attn_act_fn(attn)
        x = x * attn

        # gate head: in dynamic mode, predict the width choice via straight-through Gumbel-Softmax
        if self.mode == 'dynamic' and self.has_gate:
            channel_choice = self.gate(x_reduced).squeeze(-1).squeeze(-1)
            self.keep_gate, self.print_gate, self.print_idx = gumbel_softmax(channel_choice, dim=1, training=self.training)
            self.channel_choice = self.print_gate, self.print_idx
        else:
            self.channel_choice = None

        return x

    def get_gate(self):
        return self.channel_choice

2. Sandwich Gate Sparsification (SGS)

Since argmax is non-differentiable, previous works that place an argmax inside the network usually substitute Gumbel-Softmax for it to approximate the gradient and allow back-propagation. However, the authors observe that training the gate this way easily makes the gate collapse into a static one.
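
The gumbel_softmax helper called inside MultiHeadGate is not shown in this excerpt. A minimal straight-through Gumbel-Softmax sketch is given below; the temperature and the exact meaning and order of the three returned values (bound to keep_gate, print_gate and print_idx above) are assumptions rather than the repository's precise implementation:

    import torch
    import torch.nn.functional as F

    def gumbel_softmax(logits, dim=1, tau=1.0, training=True):
        if training:
            # sample Gumbel(0, 1) noise as -log(Exponential(1))
            gumbels = -torch.empty_like(logits).exponential_().log()
            y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
        else:
            y_soft = F.softmax(logits, dim=dim)
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        # straight-through estimator: hard one-hot in the forward pass,
        # soft probabilities in the backward pass
        y_st = y_hard - y_soft.detach() + y_soft
        return y_soft, y_st, index.squeeze(dim)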

To address this, the authors propose the Sandwich Gate Sparsification (SGS) training strategy. First, every input sample is classified as easy or hard according to whether it can be correctly predicted by the narrowest sub-network. The two classes are then given one-hot gate labels, and the gate is optimized directly with cross-entropy. This avoids the indirect, approximate gradient path, overcomes the gate's convergence difficulty, and increases the dynamic diversity of the gate. The SGS training loop is shown below.

    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            #torch.jit.trace(model, input).save('dsnet1.pt')

        if last_batch or (batch_idx + 1) % optimizer_step == 0:
            optimizer.zero_grad()
        # generate online labels
        with torch.no_grad():
            set_model_mode(model, 'smallest')
            output = model(input)
            conf_s, correct_s = accuracy(output, target, no_reduce=True)
            gate_target = [torch.LongTensor([0]) if correct_s[0][idx] else torch.LongTensor([3])
                           for idx in range(correct_s[0].size(0))]
            gate_target = torch.stack(gate_target).squeeze(-1).cuda()
        # =============

        set_model_mode(model, 'dynamic')
      
        output = model(input)

        if hasattr(model, 'module'):
            model_ = model.module
        else:
            model_ = model

        #  SGS Loss
        gate_loss = 0
        gate_num = 0
        gate_loss_l = []
        gate_acc_l = []
        for n, m in model_.named_modules():
            if isinstance(m, MultiHeadGate):
                if getattr(m, 'keep_gate', None) is not None:
                    gate_num += 1
                    g_loss = loss_fn(m.keep_gate, gate_target)
                    gate_loss += g_loss
                    gate_loss_l.append(g_loss)
                    gate_acc_l.append(accuracy(m.keep_gate, gate_target, topk=(1,))[0])

        gate_loss /= gate_num

        #  MAdds Loss
        running_flops = add_flops(model)
        if isinstance(running_flops, torch.Tensor):
            running_flops = running_flops.float().mean().cuda()
        else:
            running_flops = torch.FloatTensor([running_flops]).cuda()


        flops_loss = (running_flops / 1e9) ** 2

        #  Target Loss, back-propagate through gumbel-softmax
        ce_loss = loss_fn(output, target)

        loss = gate_loss + ce_loss + 0.5 * flops_loss
        # loss = ce_loss
        acc1 = accuracy(output, target, topk=(1,))[0]

        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

Evaluating the gate

    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
        # generate online labels
        with torch.no_grad():
            set_model_mode(model, 'smallest')
            output = model(input)
            conf_s, correct_s = accuracy(output, target, no_reduce=True)
            gate_target = [torch.LongTensor([0]) if correct_s[0][idx] else torch.LongTensor([3])
                           for idx in range(correct_s[0].size(0))]
            gate_target = torch.stack(gate_target).squeeze(-1).cuda()
        # =============
        set_model_mode(model, 'dynamic')
        output = model(input)

        if hasattr(model, 'module'):
            model_ = model.module
        else:
            model_ = model

        gate_acc_l = []
        for n, m in model_.named_modules():
            if isinstance(m, MultiHeadGate):
                if getattr(m, 'keep_gate', None) is not None:
                    gate_acc_l.append(accuracy(m.keep_gate, gate_target, topk=(1,))[0])

        running_flops = add_flops(model)
        if isinstance(running_flops, torch.Tensor):
            running_flops = running_flops.float().mean().cuda()
        else:
            running_flops = torch.FloatTensor([running_flops]).cuda()

        loss = loss_fn(output, target)
        prec1, prec5 = accuracy(output, target, topk=(1, 5))

Evaluation with validate_slim

    for choice in range(args.num_choice):
        eval_metrics.append(validate_slim(model,
                                          loader_eval,
                                          validate_loss_fn,
                                          args,
                                          model_mode=choice))

validate_slim

    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            if not isinstance(model_mode, str):
                if hasattr(model, 'module'):
                    model.module.set_mode('uniform', choice=model_mode)
                else:
                    model.set_mode('uniform', choice=model_mode)
            else:
                if hasattr(model, 'module'):
                    model.module.set_mode(model_mode)
                else:
                    model.set_mode(model_mode)
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()

            output = model(input)
            running_flops = add_flops(model)
            if isinstance(running_flops, torch.Tensor):
                running_flops = running_flops.float().mean().cuda()
            else:
                running_flops = torch.FloatTensor([running_flops]).cuda()

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

As the code shows, the gate network is evaluated in dynamic mode, while the static sub-networks are tested by evaluating each of the width choices in turn.
Note that the following code keeps only 4 channel options per layer:

    for n, m in model.named_modules():
        if len(getattr(m, 'in_channels_list', [])) > 4:
            m.in_channels_list = m.in_channels_list[start_chn_idx:4]
            m.in_channels_list_tensor = torch.from_numpy(
                np.array(m.in_channels_list)).float().cuda()
        if len(getattr(m, 'out_channels_list', [])) > 4:
            m.out_channels_list = m.out_channels_list[start_chn_idx:4]
            m.out_channels_list_tensor = torch.from_numpy(
                np.array(m.out_channels_list)).float().cuda()

Summary

From the code we can see that in Stage I (supernet training), the four sub-models ['largest', 'uniform', 'uniform', 'smallest'] are trained with soft labels; only the supernet is trained in this stage, not the gates. In Stage II (gate training), after the supernet has been trained, the gates are trained with SGS (for example, the gate at blocks[3].first_block.gate.gate outputs a tensor of shape [66, 4] in this run). At inference time, setting the mode to dynamic makes the model forward in a weight-times-gate fashion to predict the final label. For detection networks, this compression scheme seems hard to apply.

References

  • https://zhuanlan.zhihu.com/p/370208935
  • https://www.163.com/dy/article/G6QJKLHQ0511CQLG.html
  • https://zhuanlan.zhihu.com/p/354043252
  • Dynamic Slimmable Network, CVPR 2021 (Oral)