CVPR 2020 DynamicRouting Source Code Study

Learning Dynamic Routing for Semantic Segmentation (CVPR 2020 Open Access)
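GitHub source code
Megvii Technology (旷视科技) paper walkthrough on Zhihu
Dynamic Routing adaptively generates different architectures for feature encoding. The network can assign objects (or background) of different sizes to the levels with matching resolution, which enables targeted feature transformation. Compared with Auto-DeepLab, Dynamic Routing supports multi-path connections and skip connections.
I had not read any dynamic-routing models before, so many parts of the paper were unclear to me, and I went through the source code to learn it. If anything here is wrong, please point it out; discussion is welcome (,ԾㅂԾ,).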

On the budget constraint

$$
\begin{aligned}
\mathcal{C}(Node^l_s) &= \mathcal{C}(Cell^l_s) + \mathcal{C}(Gate^l_s) + \mathcal{C}(Trans^l_s) \\
&= \max(\alpha^l_s)\sum_{O^i\in\mathcal{O}}\mathcal{C}(O^i) + \mathcal{C}(Gate^l_s) + \sum_j \alpha^l_{s\to j}\,\mathcal{C}(\mathcal{T}_{s\to j})
\end{aligned}
$$

$$
\mathcal{C}(Space) = \sum_{l\le L}\sum_{s\le 1/4}\mathcal{C}(Node^l_s)
$$

$$
\mathcal{L}_C = \left(\mathcal{C}(Space)/C - \mu\right)^2
$$

$$
\mathcal{L} = \lambda_1\mathcal{L}_N + \lambda_2\mathcal{L}_C
$$
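In the formulas above, $\mathcal{C}(\cdot)$ denotes the FLOPs of the corresponding operation, $C$ is the real FLOPs of the network, and $\mu$ is the target value for the ratio $\mathcal{C}(Space)/C$. The final loss combines the semantic segmentation loss $\mathcal{L}_N$ and the budget-constraint term $\mathcal{L}_C$. The budget constraint depends on each Cell's $\alpha$ (the Soft Conditional Gate, which controls how strongly each feature path is output).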
In addition, $\lambda_2=\min\left(1, (current\_step/total\_step-UNUPDATE\_RATE)/0.02\right)$, and the budget term is only added once $current\_step/total\_step \ge UNUPDATE\_RATE$ (see warm_up_rate in DynamicNet4Seg.forward).
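To make the arithmetic concrete, here is a small pure-Python sketch of the formulas. It works on plain floats instead of tensors, and the function names and example numbers are hypothetical. How BudgetConstraint combines the warm-up factor with its own loss weight is not shown in the snippets here, so budget_weight only reproduces the warm_up_rate computed in DynamicNet4Seg.forward.

```python
# Pure-Python sketch of the budget-constraint formulas (illustrative names and numbers).


def node_cost(alpha, cell_op_flops, gate_flops, trans_flops):
    """C(Node_s^l) = max(alpha) * sum_i C(O^i) + C(Gate_s^l) + sum_j alpha_{s->j} * C(T_{s->j})."""
    cell = max(alpha) * sum(cell_op_flops)
    trans = sum(a * t for a, t in zip(alpha, trans_flops))
    return cell + gate_flops + trans


def budget_loss(node_costs, total_flops, mu):
    """L_C = (C(Space) / C - mu)^2, where C(Space) sums the cost of every node."""
    space = sum(node_costs)
    return (space / total_flops - mu) ** 2


def budget_weight(step_rate, unupdate_rate, ramp=0.02):
    """lambda_2 schedule: 0 before UNUPDATE_RATE, then a linear ramp up to 1."""
    if step_rate < unupdate_rate:
        return 0.0  # the budget loss is not added yet
    return min(1.0, (step_rate - unupdate_rate) / ramp)


def total_loss(seg_loss, budget, lambda1, lambda2):
    """L = lambda1 * L_N + lambda2 * L_C."""
    return lambda1 * seg_loss + lambda2 * budget


if __name__ == "__main__":
    # gates [0.5, 1.0, 0.0]: the cell op is paid at max(alpha) = 1.0, transforms at alpha_j
    print(node_cost([0.5, 1.0, 0.0], [10.0, 30.0], 2.0, [4.0, 1.0, 8.0]))  # 45.0
```

Note that the cell operation is charged with max(α) rather than the sum: the cell runs once no matter how many output paths are open, while each transform is charged separately with its own α.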

During the forward pass, the model also computes the relevant FLOPs along the way.

dynamic4seg.py

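The overall architecture of the semantic segmentation model is in dl_lib/modeling/meta_arch/dynamic4seg.py, which defines the whole segmentation pipeline.
After batched_inputs is received, the images are normalized, padded into one batch (ImageList.from_tensors, to a multiple of size_divisibility), and fed into the backbone. The backbone returns the intermediate features features, the expected FLOPs expt_flops used by the budget constraint, and the backbone's real FLOPs real_flops used to evaluate the computational cost. features is then passed to sem_seg_head, which upsamples it and produces the segmentation result and the loss. During training, a budget loss computed from expt_flops is added to the losses for backpropagation, constraining the model's complexity. At inference, the prediction is bilinearly resized to the original image size and returned (input images are usually resized, so their size differs from the original).
The structures of backbone and sem_seg_head are defined in dl_lib/modeling/dynamic_arch/dynamic_backbone.py and in dynamic4seg.py itself, respectively.
For the model's initialization parameters, see the config files of each model under playground/Dynamic/.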
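The padding step decides the tensor sizes seen by every later stage. The sketch below assumes dl_lib's ImageList.from_tensors behaves like detectron2's (pad each image to the batch maximum, rounded up to a multiple of size_divisibility, which is max_stride = 32 here); the helper names are hypothetical.

```python
def padded_batch_size(image_sizes, size_divisibility=32):
    """(H, W) of the batched tensor: batch max, rounded up to a multiple of size_divisibility."""
    max_h = max(h for h, _ in image_sizes)
    max_w = max(w for _, w in image_sizes)
    if size_divisibility > 1:
        d = size_divisibility
        max_h = (max_h + d - 1) // d * d
        max_w = (max_w + d - 1) // d * d
    return max_h, max_w


def pyramid_sizes(padded_hw, strides=(4, 8, 16, 32)):
    """Feature-map size at each backbone scale (1/4 ... 1/32 of the padded input)."""
    h, w = padded_hw
    return [(h // s, w // s) for s in strides]


if __name__ == "__main__":
    hw = padded_batch_size([(500, 375)])
    print(hw)                 # (512, 384)
    print(pyramid_sizes(hw))  # [(128, 96), (64, 48), (32, 24), (16, 12)]
```

Because the padded size is divisible by 32, every scale halves exactly, so the ×2 upsampling in the decoder lines up with the next level. At inference, sem_seg_postprocess then crops the prediction back to image_size and resizes it to (height, width).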

class DynamicNet4Seg(nn.Module):
    """
    This module implements Dynamic Network for Semantic Segmentation.
    """
    def __init__(self, cfg):
        super().__init__()
        self.constrain_on = cfg.MODEL.BUDGET.CONSTRAIN
        self.unupdate_rate = cfg.MODEL.BUDGET.UNUPDATE_RATE
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.backbone = cfg.build_backbone(cfg)
        self.sem_seg_head = cfg.build_sem_seg_head(
            cfg, self.backbone.output_shape())
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(
            -1, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(
            -1, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.budget_constrint = BudgetConstraint(cfg)
        self.to(self.device)

    def forward(self, batched_inputs, step_rate=0.0):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
            step_rate: a float, calculated by current_step/total_step,
                This parameter is used for Scheduled Drop Path.
        For now, each item in the list is a dict that contains:
            image: Tensor, image in (C, H, W) format.
            sem_seg: semantic segmentation ground truth
            Other information that's included in the original dicts, such as:
                "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.
        Returns:
            list[dict]: Each dict is the output for one input image.
                The dict contains one key "sem_seg" whose value is a
                Tensor of the output resolution that represents the
                per-pixel segmentation prediction.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)

        features, expt_flops, real_flops = self.backbone(
            images.tensor, step_rate)

        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, self.backbone.size_divisibility,
                self.sem_seg_head.ignore_value).tensor
        else:
            targets = None

        results, losses = self.sem_seg_head(features, targets)
        # calculate flops
        real_flops += self.sem_seg_head.flops
        flops = {'real_flops': real_flops, 'expt_flops': expt_flops}
        # use budget constraint for training
        if self.training:
            if self.constrain_on and step_rate >= self.unupdate_rate:
                warm_up_rate = min(
                    1.0, (step_rate - self.unupdate_rate) / 0.02
                )
                loss_budget = self.budget_constrint(
                    expt_flops, warm_up_rate=warm_up_rate
                )
                losses.update({'loss_budget': loss_budget})
            return losses, flops

        processed_results = []
        for result, input_per_image, image_size in zip(results, batched_inputs,
                                                       images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r, "flops": flops})
        return processed_results

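sem_seg_head upsamples and fuses the extracted multi-scale features. Its computation is as follows:
One decoder layer (Conv1×1 + BatchNorm + ReLU) is created per input feature, so the number of decoder layers equals the number of input features; each layer halves the channel count, except the one for layer_0, which keeps it. Starting from the deepest (lowest-resolution) feature, each decoder output is bilinearly upsampled ×2 to the size of the next shallower level and added element-wise to that level's feature, which then goes through its own decoder layer. Finally, a Conv3×3 predictor followed by ×4 bilinear upsampling brings the prediction to the input resolution, giving result. During training, the mean per-pixel cross-entropy loss (ignoring pixels labeled ignore_value) is computed, and the head returns the segmentation result and the loss.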
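A shape-only trace helps check why the element-wise additions line up. This is a minimal sketch that follows (C, H, W) tuples through the top-down fusion; init_channel = 64 and 19 classes are hypothetical example values. The original forward uses out_index <= 2, which assumes exactly four features; the sketch uses "a previous prediction exists", which is equivalent in that case.

```python
def trace_decoder(shapes, num_classes):
    """Shape-only trace of the top-down fusion in SemSegDecoderHead.forward.

    shapes: list of (C, H, W) for layer_0 ... layer_n, finest scale first.
    Returns the (num_classes, H, W) shape of the final prediction.
    """
    pred = None
    for idx in reversed(range(len(shapes))):
        c, h, w = shapes[idx]
        if pred is not None and pred != (c, h, w):
            # element-wise addition needs identical shapes
            raise ValueError(f"cannot add {pred} to layer_{idx} {(c, h, w)}")
        out_c = c if idx == 0 else c // 2  # 1x1 conv halves channels except for layer_0
        pred = (out_c, h, w)
        if idx > 0:
            pred = (out_c, h * 2, w * 2)  # bilinear x2 upsampling
    _, h, w = pred
    # 3x3 predictor maps to num_classes, then x4 bilinear upsampling
    return (num_classes, h * 4, w * 4)


if __name__ == "__main__":
    feats = [(64, 128, 256), (128, 64, 128), (256, 32, 64), (512, 16, 32)]
    print(trace_decoder(feats, 19))  # (19, 512, 1024)
```

Halving the channels in every decoder layer except layer_0 is what makes the upsampled output of level k match the channel count C·2^(k-1) of level k-1.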

class SemSegDecoderHead(nn.Module):
    """
    This module implements simple decoder head for Semantic Segmentation.
    It creats decoder on top of the dynamic backbone.
    """
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        # fmt: off
        self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        feature_strides = {k: v.stride for k, v in input_shape.items()}  # noqa:F841
        feature_channels = {k: v.channels for k, v in input_shape.items()}
        feature_resolution = {
            k: np.array([v.height, v.width])
            for k, v in input_shape.items()
        }
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        norm = cfg.MODEL.SEM_SEG_HEAD.NORM
        self.loss_weight = cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT
        self.cal_flops = cfg.MODEL.CAL_FLOPS
        self.real_flops = 0.0
        # fmt: on

        self.layer_decoder_list = nn.ModuleList()
        # set affine in BatchNorm
        if 'Sync' in norm:
            affine = True
        else:
            affine = False
        # use simple decoder
        for _feat in self.in_features:
            res_size = feature_resolution[_feat]
            in_channel = feature_channels[_feat]
            if _feat == 'layer_0':
                out_channel = in_channel
            else:
                out_channel = in_channel // 2
            conv_1x1 = Conv2d(in_channel,
                              out_channel,
                              kernel_size=1,
                              stride=1,
                              padding=0,
                              bias=False,
                              norm=get_norm(norm, out_channel),
                              activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                res_size[0],
                res_size[1],
                in_channel,
                out_channel, [1, 1],
                is_affine=affine)
            self.layer_decoder_list.append(conv_1x1)
        # using Kaiming init
        for layer in self.layer_decoder_list:
            weight_init.kaiming_init_module(layer, mode='fan_in')
        in_channel = feature_channels['layer_0']
        # the output layer
        self.predictor = Conv2d(in_channels=in_channel,
                                out_channels=num_classes,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        self.real_flops += cal_op_flops.count_Conv_flop(
            feature_resolution['layer_0'][0], feature_resolution['layer_0'][1],
            in_channel, num_classes, [3, 3])
        # using Kaiming init
        weight_init.kaiming_init_module(self.predictor, mode='fan_in')

    def forward(self, features, targets=None):
        pred, pred_output = None, None
        for _index in range(len(self.in_features)):
            out_index = len(self.in_features) - _index - 1
            out_feat = features[self.in_features[out_index]]
            if out_index <= 2:
                out_feat = pred + out_feat
            pred = self.layer_decoder_list[out_index](out_feat)
            if out_index > 0:
                pred = F.interpolate(input=pred,
                                     scale_factor=2,
                                     mode='bilinear',
                                     align_corners=False)
            else:
                pred_output = pred
        # pred output
        pred_output = self.predictor(pred_output)
        pred_output = F.interpolate(input=pred_output,
                                    scale_factor=4,
                                    mode='bilinear',
                                    align_corners=False)

        if self.training:
            losses = {}
            losses["loss_sem_seg"] = (
                F.cross_entropy(
                    pred_output, targets, reduction="mean",
                    ignore_index=self.ignore_value
                ) * self.loss_weight
            )
            return [], losses
        else:
            return pred_output, {}

    @property
    def flops(self):
        return self.real_flops
dynamic_backbone.py

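dl_lib/modeling/dynamic_arch/dynamic_backbone.py defines the backbone structure and its forward pass. For the initialization parameters, see the config files of each model under playground/Dynamic/.
The input first passes through the STEM module (three Conv3×3 layers), which downsamples it to 1/4 of the input resolution. An initial Cell (init_layer) then processes the STEM features so that the subsequent Cells have outputs to aggregate. The input of the Cell at scale $s$ in layer $l$ is the sum of the features that layer $l-1$ routes to this scale, $X^l_s = Y^{l-1}_{s/2} + Y^{l-1}_{s} + Y^{l-1}_{2s}$ (only the neighbors that exist and are allowed by all_cell_type_list contribute), and the outputs are then computed Cell by Cell.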
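Which neighbors feed a given Cell is decided by the [dim_up, dim_keep, dim_down] flags built in DynamicNetwork.__init__. The sketch below re-derives those flags in plain Python so the routing space can be inspected without building the model. The cell_num_list value [2, 3, 4, 4, 4] is a hypothetical example, and features are stand-in numbers rather than tensors.

```python
def input_sources(cell_num_list, layer_index, cell_index):
    """Cells of layer l-1 that feed cell `cell_index` of layer `layer_index`.

    Returns (prev_cell_index, branch) pairs, where branch is the sender's output:
    'down' (finer neighbour via res_down), 'keep' (same scale) or 'up' (coarser neighbour via res_up).
    Mirrors the dim_up / dim_keep / dim_down rules in DynamicNetwork.__init__.
    """
    n = cell_num_list[layer_index]
    dim_up = cell_index > 0
    if layer_index > 2:
        dim_down = cell_index < n - 1
    else:
        dim_down = cell_index < n - 2
    dim_keep = not (layer_index <= 2 and cell_index == n - 1)

    sources = []
    if dim_up:
        sources.append((cell_index - 1, "down"))
    if dim_keep:
        sources.append((cell_index, "keep"))
    if dim_down:
        sources.append((cell_index + 1, "up"))
    return sources


def aggregate(prev_outputs, sources):
    """X_s^l as the sum of the selected outputs; prev_outputs[i][branch] is a number here."""
    return sum(prev_outputs[i][branch] for i, branch in sources)


if __name__ == "__main__":
    cells = [2, 3, 4, 4, 4]  # hypothetical cell_num_list
    for layer in range(len(cells)):
        print(layer, [input_sources(cells, layer, c) for c in range(cells[layer])])
```

In the first three layers the last Cell is a new, coarser scale, so it only receives the 'down' output of its finer neighbor; from layer 3 on all four scales exist and every Cell keeps its same-scale input.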

class DynamicNetwork(Backbone):
    """
    This module implements Dynamic Routing Network.
    It creates dense connected network on top of some input feature maps.
    """
    def __init__(
        self, init_channel, input_shape, cell_num_list, layer_num,
        ext_layer=None, norm="", cal_flops=True, cell_type='',
        max_stride=32, sep_stem=True, using_gate=False,
        small_gate=False, gate_bias=1.5, drop_prob=0.0,
    ):
        super(DynamicNetwork, self).__init__()
        # set affine in BatchNorm
        if 'Sync' in norm:
            self.affine = True
        else:
            self.affine = False
        # set scheduled drop path
        self.drop_prob = drop_prob
        if self.drop_prob > 0.0001:
            self.drop_path = True
        else:
            self.drop_path = False
        self.cal_flops = cal_flops
        self._size_divisibility = max_stride
        input_res = np.array(input_shape[1:3])

        self.stem = DynamicStem(
            3, out_channels=init_channel, input_res=input_res,
            sept_stem=sep_stem, norm=norm, affine=self.affine
        )
        self.stem_flops = self.stem.flops
        self._out_feature_strides = {"stem": self.stem.stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}
        self._out_feature_resolution = {"stem": self.stem.out_resolution}
        assert self.stem.out_channels == init_channel
        self.all_cell_list = nn.ModuleList()
        self.all_cell_type_list = []
        self.cell_num_list = cell_num_list[:layer_num]
        self._out_features = []
        # using the initial layer
        input_res = input_res // self.stem.stride
        in_channel = out_channel = init_channel
        self.init_layer = Cell(
            C_in=in_channel, C_out=out_channel, norm=norm, allow_up=False,
            allow_down=True, input_size=input_res, cell_type=cell_type,
            cal_flops=False, using_gate=using_gate, small_gate=small_gate,
            gate_bias=gate_bias, affine=self.affine
        )

        # add cells in each layer
        for layer_index in range(len(self.cell_num_list)):
            layer_cell_list = nn.ModuleList()
            layer_cell_type = []
            for cell_index in range(self.cell_num_list[layer_index]):
                # channel multi, when stride:4 -> channel:C, stride:8 -> channel:2C ...
                channel_multi = pow(2, cell_index)
                in_channel_cell = in_channel * channel_multi
                # add res and dim switch to each cell
                allow_up = True
                allow_down = True
                # add res up and dim down by 2
                if cell_index == 0 or layer_index == layer_num - 1:
                    allow_up = False
                # dim down and resolution up by 2
                if cell_index == 3 or layer_index == layer_num - 1:
                    allow_down = False
                res_size = input_res // channel_multi
                layer_cell_list.append(
                    Cell(
                        C_in=in_channel_cell, C_out=in_channel_cell, norm=norm,
                        allow_up=allow_up, allow_down=allow_down,
                        input_size=res_size, cell_type=cell_type,
                        cal_flops=cal_flops, using_gate=using_gate,
                        small_gate=small_gate, gate_bias=gate_bias,
                        affine=self.affine
                    )
                )
                # allow dim change in each aggregation
                dim_up, dim_down, dim_keep = False, False, True
                # dim up and resolution down by 2
                if cell_index > 0:
                    dim_up = True
                # dim down and resolution up by 2
                if (cell_index < self.cell_num_list[layer_index] - 1) and layer_index > 2:
                    dim_down = True
                elif (cell_index < self.cell_num_list[layer_index] - 2) and layer_index <= 2:
                    dim_down = True
                # dim keep unchanged
                if layer_index <= 2 and cell_index == self.cell_num_list[layer_index] - 1:
                    dim_keep = False
                # allowed cell operations
                layer_cell_type.append([dim_up, dim_keep, dim_down])
                if layer_index == len(self.cell_num_list) - 1:
                    name = 'layer_' + str(cell_index)
                    self._out_feature_strides[name] = channel_multi * self.stem.stride
                    self._out_feature_channels[name] = in_channel_cell
                    self._out_feature_resolution[name] = res_size
                    self._out_features.append(name)
            self.all_cell_list.append(layer_cell_list)
            self.all_cell_type_list.append(layer_cell_type)

    @property
    def size_divisibility(self):
        return self._size_divisibility

    def forward(self, x, step_rate=0.0):
        h_l1 = self.stem(x)
        # the initial layer
        h_l1_list, h_beta_list, trans_flops, trans_flops_real = self.init_layer(h_l1=h_l1)
        prev_beta_list, prev_out_list = [h_beta_list], [h_l1_list]  # noqa: F841
        prev_trans_flops, prev_trans_flops_real = [trans_flops], [trans_flops_real]
        # build forward outputs
        cell_flops_list, cell_flops_real_list = [], []
        for layer_index in range(len(self.cell_num_list)):
            layer_input, layer_output = [], []
            layer_trans_flops, layer_trans_flops_real = [], []
            flops_in_expt_list, flops_in_real_list = [], []
            layer_rate = (layer_index + 1) / float(len(self.cell_num_list))
            # aggregate cell input
            for cell_index in range(len(self.all_cell_type_list[layer_index])):
                cell_input, trans_flops_input, trans_flops_real_input = [], [], []
                if self.all_cell_type_list[layer_index][cell_index][0]:
                    cell_input.append(prev_out_list[cell_index - 1][2][0])
                    trans_flops_input.append(prev_trans_flops[cell_index - 1][2][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index - 1][2][0])
                if self.all_cell_type_list[layer_index][cell_index][1]:
                    cell_input.append(prev_out_list[cell_index][1][0])
                    trans_flops_input.append(prev_trans_flops[cell_index][1][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index][1][0])
                if self.all_cell_type_list[layer_index][cell_index][2]:
                    cell_input.append(prev_out_list[cell_index + 1][0][0])
                    trans_flops_input.append(prev_trans_flops[cell_index + 1][0][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index + 1][0][0])

                h_l1 = sum(cell_input)
                # calculate input for gate
                layer_input.append(h_l1)
                # calculate FLOPs input
                flops_in_expt = sum(_flops for _flops in trans_flops_input)
                flop_in_real = sum(_flops for _flops in trans_flops_real_input)
                flops_in_expt_list.append(flops_in_expt)
                flops_in_real_list.append(flop_in_real)

            # calculate each cell
            for _cell_index in range(len(self.all_cell_type_list[layer_index])):
                if self.cal_flops:
                    cell_output, gate_weights_beta, cell_flops, \
                        cell_flops_real, trans_flops, trans_flops_real = \
                        self.all_cell_list[layer_index][_cell_index](
                            h_l1=layer_input[_cell_index],
                            flops_in_expt=flops_in_expt_list[_cell_index],
                            flops_in_real=flops_in_real_list[_cell_index],
                            is_drop_path=self.drop_path, drop_prob=self.drop_prob,
                            layer_rate=layer_rate, step_rate=step_rate
                        )
                    # calculate real flops
                    cell_flops_list.append(cell_flops)
                    cell_flops_real_list.append(cell_flops_real)
                else:
                    cell_output, gate_weights_beta, trans_flops, trans_flops_real = \
                        self.all_cell_list[layer_index][_cell_index](
                            h_l1=layer_input[_cell_index],
                            flops_in_expt=flops_in_expt_list[_cell_index],
                            flops_in_real=flops_in_real_list[_cell_index],
                            is_drop_path=self.drop_path, drop_prob=self.drop_prob,
                            layer_rate=layer_rate, step_rate=step_rate
                        )

                layer_output.append(cell_output)
                # update trans flops output
                layer_trans_flops.append(trans_flops)
                layer_trans_flops_real.append(trans_flops_real)
            # update layer output
            prev_out_list = layer_output
            prev_trans_flops = layer_trans_flops
            prev_trans_flops_real = layer_trans_flops_real

        final_out_list = [prev_out_list[_i][1][0] for _i in range(len(prev_out_list))]
        final_out_dict = dict(zip(self._out_features, final_out_list))
        if self.cal_flops:
            all_cell_flops = torch.mean(sum(cell_flops_list))
            all_flops_real = torch.mean(sum(cell_flops_real_list)) + self.stem_flops
        else:
            all_cell_flops, all_flops_real = None, None
        return final_out_dict, all_cell_flops, all_flops_real

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name],
                height=self._out_feature_resolution[name][0],
                width=self._out_feature_resolution[name][1],  # resolution is stored as [height, width]
                stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
dynamic_cell.py

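dl_lib/modeling/dynamic_arch/dynamic_cell.py defines the computation of a Cell.
First, compute $G^l_s = \mathcal{F}(w^l_{s,2}, \mathcal{G}(\sigma(\mathcal{N}(\mathcal{F}(w^l_{s,1}, X^l_s))))) + \beta^l_s$ (gate_conv_beta in the code), where $\mathcal{F}$ is a 1×1 convolution, $\mathcal{N}$ normalization, $\sigma$ ReLU, $\mathcal{G}$ global average pooling, and $\beta^l_s$ the gate bias (initialized to gate_bias). Applying soft_gate to $G^l_s$ gives the soft conditional gate $\alpha^l_s$, with one value per routing direction.
If $\alpha^l_s$ is small enough (the gate values sum to less than 1e-4), the Cell is dropped: its up and down outputs become zero and the keep path simply forwards the input. At inference the Cell returns early; during training gate_mask produces the same keep-path output.
Otherwise, the Cell Operation ($SepConv3\times3$, a depthwise separable convolution: $Conv3\times3+Conv1\times1+Conv3\times3+Conv1\times1$) extracts features $H^l_s$ from the input. $H^l_s$ is then upsampled (ReLU + Conv1×1 halving the channels + ×2 bilinear), kept at the same scale (ReLU), and downsampled (ReLU + stride-2 Conv1×1 doubling the channels), and the Cell outputs $Y^l_{s\to j}=\alpha^l_{s\to j}\mathcal{T}_{s\to j}(H^l_s)$.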
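The gating and routing logic can be reduced to scalars to see what each branch returns. This is a sketch, not the repository's code: it assumes soft_gate computes max(0, tanh(·)), the activation the paper uses for the gate, and cell_op and the transforms are placeholder callables standing in for Mixed_OP, res_keep, res_up and res_down.

```python
import math


def soft_gate(g):
    """Assumed gate activation: max(0, tanh(g))."""
    return max(0.0, math.tanh(g))


def cell_forward(x, gate_logits, cell_op, transforms, eps=1e-4):
    """Scalar version of the routing in Cell.forward.

    gate_logits: {direction: G value} for the allowed directions ('keep', 'up', 'down')
    cell_op:     callable producing H_s^l from the input X_s^l
    transforms:  {direction: callable T_{s->j}}
    """
    alphas = {j: soft_gate(g) for j, g in gate_logits.items()}
    if sum(alphas.values()) < eps:
        # dropped cell: the keep path forwards the input, other paths output zero
        return {j: (x if j == "keep" else 0.0) for j in alphas}
    h = cell_op(x)
    # Y_{s->j} = alpha_{s->j} * T_{s->j}(H_s^l)
    return {j: alphas[j] * transforms[j](h) for j in alphas}


if __name__ == "__main__":
    t = {"keep": lambda v: v, "down": lambda v: v + 1}
    print(cell_forward(2.0, {"keep": 100.0, "down": -3.0}, lambda v: 3 * v, t))
    print(cell_forward(5.0, {"keep": -1.0, "down": -2.0}, lambda v: v, t))
```

Because tanh is clipped at zero, a negative gate logit closes a path exactly, which is what allows whole Cells to be skipped at inference. With gate_bias = 1.5, the bias alone would give tanh(1.5) ≈ 0.91, so paths tend to start open.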

class Cell(nn.Module):
    def __init__(
        self, C_in, C_out, norm, allow_up, allow_down, input_size,
        cell_type, cal_flops=True, using_gate=False,
        small_gate=False, gate_bias=1.5, affine=True
    ):
        super(Cell, self).__init__()
        self.channel_in = C_in
        self.channel_out = C_out
        self.allow_up = allow_up
        self.allow_down = allow_down
        self.cal_flops = cal_flops
        self.using_gate = using_gate
        self.small_gate = small_gate

        self.cell_ops = Mixed_OP(
            inplanes=self.channel_in, outplanes=self.channel_out,
            stride=1, cell_type=cell_type, norm=norm,
            affine=affine, input_size=input_size
        )
        self.cell_flops = self.cell_ops.flops
        # resolution keep
        self.res_keep = nn.ReLU()
        self.res_keep_flops = cal_op_flops.count_ReLU_flop(
            input_size[0], input_size[1], self.channel_out
        )
        # resolution up and dim down
        if self.allow_up:
            self.res_up = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, self.channel_out // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_out // 2),
                    activation=nn.ReLU()
                )
            )
            # calculate Flops
            self.res_up_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                self.channel_out // 2, [1, 1], is_affine=affine
            )
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_up, mode='fan_in')
        # resolution down and dim up
        if self.allow_down:
            self.res_down = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, 2 * self.channel_out,
                    kernel_size=1, stride=2, padding=0, bias=False,
                    norm=get_norm(norm, 2 * self.channel_out),
                    activation=nn.ReLU()
                )
            )
            # calculate Flops
            self.res_down_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                2 * self.channel_out, [1, 1], stride=2, is_affine=affine
            )
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_down, mode='fan_in')
        if self.allow_up and self.allow_down:
            self.gate_num = 3
        elif self.allow_up or self.allow_down:
            self.gate_num = 2
        else:
            self.gate_num = 1
        if self.using_gate:
            self.gate_conv_beta = nn.Sequential(
                Conv2d(
                    self.channel_in, self.channel_in // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_in // 2),
                    activation=nn.ReLU()
                ),
                nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(
                    self.channel_in // 2, self.gate_num, kernel_size=1,
                    stride=1, padding=0, bias=True
                )
            )
            if self.small_gate:
                input_size = input_size // 4
            self.gate_flops = cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_in,
                self.channel_in // 2, [1, 1], is_affine=affine
            ) + cal_op_flops.count_Pool2d_flop(
                input_size[0], input_size[1], self.channel_in // 2, [1, 1], 1
            ) + cal_op_flops.count_Conv_flop(
                1, 1, self.channel_in // 2, self.gate_num, [1, 1]
            )
            # using Kaiming init and predefined bias for gate
            weight_init.kaiming_init_module(
                self.gate_conv_beta, mode='fan_in', bias=gate_bias
            )
        else:
            self.register_buffer(
                'gate_weights_beta', torch.ones(1, self.gate_num, 1, 1).cuda()
            )
            self.gate_flops = 0.0

    def forward(
        self, h_l1, flops_in_expt=None, flops_in_real=None,
        is_drop_path=False, drop_prob=0.0,
        layer_rate=0.0, step_rate=0.0
    ):
        """
        :param h_l1: # the former hidden layer output
        :return: current hidden cell result h_l
        """
        drop_cell = False
        # drop the cell if input type is float
        if not isinstance(h_l1, float):
            # calculate soft conditional gate
            if self.using_gate:
                if self.small_gate:
                    h_l1_gate = F.interpolate(
                        input=h_l1, scale_factor=0.25,
                        mode='bilinear', align_corners=False
                    )
                else:
                    h_l1_gate = h_l1
                gate_feat_beta = self.gate_conv_beta(h_l1_gate)
                gate_weights_beta = soft_gate(gate_feat_beta)
            else:
                gate_weights_beta = self.gate_weights_beta
        else:
            drop_cell = True
        # use for inference
        if not self.training:
            if not drop_cell:
                drop_cell = gate_weights_beta.sum() < 0.0001
            if drop_cell:
                result_list = [[0.0], [h_l1], [0.0]]
                weights_list_beta = [[0.0], [0.0], [0.0]]
                trans_flops_expt = [[0.0], [0.0], [0.0]]
                trans_flops_real = [[0.0], [0.0], [0.0]]
                if self.cal_flops:
                    h_l_flops = flops_in_expt
                    h_l_flops_real = flops_in_real + self.gate_flops
                    return (
                        result_list, weights_list_beta, h_l_flops,
                        h_l_flops_real, trans_flops_expt, trans_flops_real
                    )
                else:
                    return (
                        result_list, weights_list_beta,
                        trans_flops_expt, trans_flops_real
                    )

        h_l = self.cell_ops(h_l1, is_drop_path, drop_prob, layer_rate, step_rate)

        # resolution and dimension change
        # resolution: [up, keep, down]
        h_l_keep = self.res_keep(h_l)
        gate_weights_beta_keep = gate_weights_beta[:, 0].unsqueeze(-1)
        # using residual connection if drop cell
        gate_mask = (gate_weights_beta.sum(dim=1, keepdim=True) < 0.0001).float()
        result_list = [[], [gate_mask * h_l1 + gate_weights_beta_keep * h_l_keep], []]
        weights_list_beta = [[], [gate_mask * 1.0 + gate_weights_beta_keep], []]
        # calculate flops for keep res
        gate_mask_keep = (gate_weights_beta_keep > 0.0001).float()
        trans_flops_real = [[], [gate_mask_keep * self.res_keep_flops], []]
        # calculate trans flops
        trans_flops_expt = [[], [self.res_keep_flops * gate_weights_beta_keep], []]

        if self.allow_up:
            h_l_up = self.res_up(h_l)
            h_l_up = F.interpolate(
                input=h_l_up, scale_factor=2, mode='bilinear', align_corners=False
            )
            gate_weights_beta_up = gate_weights_beta[:, 1].unsqueeze(-1)
            result_list[0].append(h_l_up * gate_weights_beta_up)
            weights_list_beta[0].append(gate_weights_beta_up)
            trans_flops_expt[0].append(self.res_up_flops * gate_weights_beta_up)
            # calculate flops for up res
            gate_mask_up = (gate_weights_beta_up > 0.0001).float()
            trans_flops_real[0].append(gate_mask_up * self.res_up_flops)

        if self.allow_down:
            h_l_down = self.res_down(h_l)
            gate_weights_beta_down = gate_weights_beta[:, -1].unsqueeze(-1)
            result_list[2].append(h_l_down * gate_weights_beta_down)
            weights_list_beta[2].append(gate_weights_beta_down)
            trans_flops_expt[2].append(self.res_down_flops * gate_weights_beta_down)
            # calculate flops for down res
            gate_mask_down = (gate_weights_beta_down > 0.0001).float()
            trans_flops_real[2].append(gate_mask_down * self.res_down_flops)

        if self.cal_flops:
            cell_flops = gate_weights_beta.max(dim=1, keepdim=True)[0] * self.cell_flops
            cell_flops_real = (
                gate_weights_beta.sum(dim=1, keepdim=True) > 0.0001
            ).float() * self.cell_flops
            h_l_flops = cell_flops + flops_in_expt
            h_l_flops_real = cell_flops_real + flops_in_real + self.gate_flops
            return (
                result_list, weights_list_beta, h_l_flops,
                h_l_flops_real, trans_flops_expt, trans_flops_real
            )
        else:
            return result_list, weights_list_beta, trans_flops_expt, trans_flops_real
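The FLOPs bookkeeping at the end of Cell.forward can be checked by hand with a plain-Python version of the same expressions. Each Cell reports the cost of its own operation plus the transform FLOPs of the paths that feed it (flops_in_expt / flops_in_real), so summing over all Cells gives $\mathcal{C}(Space)$; the backbone then averages that sum over the batch. In the snippet the gate FLOPs only enter the real count. The function names below are hypothetical.

```python
def cell_flops_pair(alphas, cell_flops, gate_flops, flops_in_expt, flops_in_real, eps=1e-4):
    """(expected, real) FLOPs of one cell, following the end of Cell.forward."""
    # expected: differentiable, scaled by the largest gate value
    expt = max(alphas) * cell_flops + flops_in_expt
    # real: the cell op counts fully if any path is open
    active = 1.0 if sum(alphas) > eps else 0.0
    real = active * cell_flops + flops_in_real + gate_flops
    return expt, real


def trans_flops_pair(alphas, trans_costs, eps=1e-4):
    """Per-path (expected, real) FLOPs of the transforms T_{s->j}."""
    expt = [a * c for a, c in zip(alphas, trans_costs)]
    real = [c if a > eps else 0.0 for a, c in zip(alphas, trans_costs)]
    return expt, real


if __name__ == "__main__":
    print(cell_flops_pair([0.5, 0.0, 0.25], 100.0, 10.0, 20.0, 30.0))  # (70.0, 140.0)
    print(trans_flops_pair([0.5, 0.0, 0.25], [8.0, 4.0, 16.0]))         # ([4.0, 0.0, 4.0], [8.0, 0.0, 16.0])
```

The expected count is what the budget loss backpropagates through, since it varies smoothly with α; the real count uses hard thresholds and is only reported for evaluation.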
