FBNet代码理解

FBNet_A网络参数
"fbnet_a": {
        "input_size": 224,
        "blocks": [
            # [op, c, s, n, ...]
            #   op: primitive name, c: output channels, s: stride,
            #   n: number of repeats. An optional fifth item (e1/e3/e6) is an
            #   extra per-block config, presumably the expansion ratio of the
            #   inverted-residual ("ir_*") blocks; e1/e3/e6 are defined outside
            #   this snippet -- TODO confirm.
            #   "# TBSk" presumably marks the end of the k-th "to be searched"
            #   layer group of the FBNet search space.
            # stage 0
            [["conv_k3", 16, 2, 1]],  # 3x3 conv: output channels, stride, repeats
            # stage 1
            [["skip", 16, 1, 1]],  # TBS1
            # stage 2
            [
                ["ir_k3", 24, 2, 1, e3],  # output channels, stride, repeats, expansion
                ["ir_k3", 24, 1, 1, e1],
                ["skip", 24, 1, 1],
                ["skip", 24, 1, 1],  # TBS2
            ],
            # stage 3
            [
                ["ir_k5", 32, 2, 1, e6],
                ["ir_k3", 32, 1, 1, e3],
                ["ir_k5", 32, 1, 1, e1],
                ["ir_k3", 32, 1, 1, e3],  # TBS3
            ],
            # stage 4
            [
                ["ir_k5", 64, 2, 1, e6],
                ["ir_k5", 64, 1, 1, e3],
                ["ir_k5_g2", 64, 1, 1, e1],
                ["ir_k5", 64, 1, 1, e6],  # TBS4
                ["ir_k3", 112, 1, 1, e6],
                ["ir_k5_g2", 112, 1, 1, e1],
                ["ir_k5", 112, 1, 1, e3],
                ["ir_k3_g2", 112, 1, 1, e1],  # TBS5
            ],
            # stage 5
            [
                ["ir_k5", 184, 2, 1, e6],
                ["ir_k5", 184, 1, 1, e6],
                ["ir_k5", 184, 1, 1, e3],
                ["ir_k5", 184, 1, 1, e6],  # TBS6
                ["ir_k5", 352, 1, 1, e6],  # TBS7
            ],
            # stage 6: final 1x1 conv to 1504 channels (written as a tuple;
            # presumably accepted the same way as the list entries above)
            [("conv_k1", 1504, 1, 1)],
        ],
    },

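具体各层参数:
参数解析结果(即 `mbuilder.unify_arch_def` 将上面的 `blocks` 展开后得到的逐层参数,每一层对应一个包含 `stage_idx`、`block_idx`、`block_op`、`block_cfg` 的字典)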

class FBNet(nn.Module):
    """Image classifier: an FBNet backbone followed by a conv classification head.

    Args:
        arch_name: architecture key (e.g. ``"fbnet_a"``) forwarded to
            ``FBNetBackbone``.
        dim_in: number of input channels.
        num_classes: number of output classes produced by the head.
    """

    def __init__(self, arch_name, dim_in=3, num_classes=1000):
        super().__init__()
        # Feature extractor; its ``out_channels`` sizes the classifier input.
        backbone = FBNetBackbone(arch_name, dim_in)
        self.backbone = backbone
        self.head = ClsConvHead(backbone.out_channels, num_classes)

    def forward(self, x):
        """Run the backbone, then the classification head, on ``x``."""
        return self.head(self.backbone(x))

    @property
    def arch_def(self):
        """The architecture definition held by the backbone."""
        return self.backbone.arch_def

class FBNetBackbone(nn.Module):
    """FBNet feature extractor: the stacked stages plus an optional dropout.

    Args:
        arch_name: architecture key (e.g. ``"fbnet_a"``) resolved by
            ``_create_builder``.
        dim_in: number of input channels of the first block.
    """

    def __init__(self, arch_name, dim_in=3):
        super().__init__()

        # Builder plus the unified per-block definitions for this arch.
        builder, arch_def = _create_builder(arch_name)

        self.stages = builder.build_blocks(arch_def["blocks"], dim_in=dim_in)
        # "dropout_ratio" is filled into arch_def by _create_builder.
        self.dropout = misc.add_dropout(arch_def["dropout_ratio"])
        # Read after build_blocks: the builder tracks the last block's width.
        self.out_channels = builder.last_depth
        self.arch_def = arch_def

    def forward(self, x):
        """Apply all stages, then dropout when one was configured."""
        features = self.stages(x)
        return features if self.dropout is None else self.dropout(features)


def _create_builder(
    arch_name_or_def: typing.Union[str, dict],
    width_ratio: float = 1.0,
    width_divisor: int = 1,
    bn_args: typing.Optional[dict] = None,
    dropout_ratio: float = 0.2,
):
    """Resolve an architecture definition and create a configured builder.

    Args:
        arch_name_or_def: either a key of ``modeldef.MODEL_ARCH`` (e.g.
            ``"fbnet_a"``) or an architecture definition dict.
        width_ratio: passed to the builder as ``width_ratio``; it multiplies
            every block's ``out_channels`` (default 1.0, the previous value).
        width_divisor: passed to the builder as ``width_divisor``
            (default 1, the previous value).
        bn_args: batch-norm config passed to the builder. ``None`` selects
            ``{"name": "bn", "momentum": 0.003}``, the previous hard-coded value.
        dropout_ratio: stored as ``arch_def["dropout_ratio"]``
            (default 0.2, the previous value).

    Returns:
        ``(builder, arch_def)``. ``arch_def`` is the definition returned by
        ``mbuilder.unify_arch_def`` with ``"dropout_ratio"`` added.

    Raises:
        AssertionError: unknown arch name, or an argument that is neither a
            ``str`` nor a ``dict``.
    """
    if isinstance(arch_name_or_def, str):
        assert arch_name_or_def in modeldef.MODEL_ARCH, (
            f"Invalid arch name {arch_name_or_def}, "
            f"available names: {modeldef.MODEL_ARCH.keys()}"
        )
        arch_def = modeldef.MODEL_ARCH[arch_name_or_def]
    else:
        assert isinstance(arch_name_or_def, dict)
        arch_def = arch_name_or_def

    # Normalize the compact "blocks" spec into the per-block dicts that
    # FBNetBuilder.build_blocks expects.
    arch_def = mbuilder.unify_arch_def(arch_def, ["blocks"])

    if bn_args is None:
        # Fresh dict per call so no two builders share one mutable default.
        bn_args = {"name": "bn", "momentum": 0.003}

    # NOTE(review): this writes into the dict returned by unify_arch_def; if
    # that helper does not copy its input, the caller's dict / the
    # MODEL_ARCH entry is mutated too -- confirm.
    arch_def["dropout_ratio"] = dropout_ratio

    builder = mbuilder.FBNetBuilder(
        width_ratio=width_ratio, bn_args=bn_args, width_divisor=width_divisor
    )
    # Arch-level defaults that the builder passes to every primitive.
    builder.add_basic_args(**arch_def.get("basic_args", {}))

    return builder, arch_def


class FBNetBuilder:
    """Builds FBNet blocks one after another while tracking the channel count.

    ``last_depth`` holds the output width of the most recently built block
    (``-1`` until it is known). Each new block takes it as its input width,
    so consecutive blocks chain automatically.

    Args:
        width_ratio: multiplier applied to every block's ``out_channels``.
        bn_args: batch-norm config, normalized with ``hp.unify_args``.
        width_divisor: used when rounding block widths via
            ``hp.get_divisible_by``. It is also forwarded to every primitive
            through ``basic_args``.
    """

    def __init__(self, width_ratio=1.0, bn_args="bn", width_divisor=1):
        self.width_ratio = width_ratio
        # Output width of the last built block; -1 means "not known yet".
        self.last_depth = -1
        self.width_divisor = width_divisor
        # Arguments provided to all primitives; per-primitive parameters
        # override them (see build_block's merge order).
        self.basic_args = {
            "bn_args": hp.unify_args(bn_args),
            "width_divisor": width_divisor,
        }

    def add_basic_args(self, **kwargs):
        """Update the arguments that are passed to all primitives.

        Per-block primitive parameters can still override them.
        """
        hp.update_dict(self.basic_args, kwargs)

    def build_blocks(
        self,
        blocks,
        stage_indices=None,
        dim_in=None,
        prefix_name="xif",
        **kwargs,
    ):
        """Build a list of unified block definitions into an ``nn.Sequential``.

        Args:
            blocks: list of block dicts, each with ``"stage_idx"``,
                ``"block_idx"``, ``"block_op"`` and ``"block_cfg"`` keys.
            stage_indices: if given, only blocks whose ``"stage_idx"`` is in
                it are built.
            dim_in: input channel count (an int) for the first block. It is
                required when ``last_depth`` is still unknown. Pass it too
                when the first block is not connected to the most recently
                built block, since ``last_depth`` would then be inaccurate.
            prefix_name: prefix of the submodule names, which are
                ``f"{prefix_name}{stage_idx}_{block_idx}"``.
            **kwargs: extra keyword arguments for every ``build_block`` call.
                They are combined with each block's own kwargs by
                ``update_with_block_kwargs``.

        Returns:
            ``nn.Sequential`` of the built blocks. Its ``out_channels``
            attribute equals the output width of the last block (or the
            input width when no block was built).

        Raises:
            AssertionError: ``blocks`` is not a list of dicts, the input
                width is unknown, or two blocks map to the same name.
        """
        assert isinstance(blocks, list) and all(
            isinstance(x, dict) for x in blocks
        ), blocks

        if stage_indices is not None:
            blocks = [x for x in blocks if x["stage_idx"] in stage_indices]

        if dim_in is not None:
            self.last_depth = dim_in
        assert (
            self.last_depth != -1
        ), "Invalid input dimension. Pass `dim_in` to `build_blocks`."

        modules = OrderedDict()
        for block in blocks:
            stage_idx = block["stage_idx"]
            block_idx = block["block_idx"]
            # Fresh copy per block so update_with_block_kwargs cannot alter
            # the kwargs shared by all blocks.
            cur_kwargs = update_with_block_kwargs(copy.deepcopy(kwargs), block)
            # dim_in=None: each block takes its input width from last_depth,
            # i.e. from the block built just before it.
            nnblock = self.build_block(
                block["block_op"], block["block_cfg"], dim_in=None, **cur_kwargs
            )
            nn_name = f"{prefix_name}{stage_idx}_{block_idx}"
            assert nn_name not in modules, f"Duplicate block name: {nn_name}"
            modules[nn_name] = nnblock
        ret = nn.Sequential(modules)
        ret.out_channels = self.last_depth
        return ret

    def build_block(self, block_op, block_cfg, dim_in=None, **kwargs):
        """Build one primitive and advance ``last_depth`` to its output width.

        Args:
            block_op: primitive name, looked up with ``PRIMITIVES.get``.
            block_cfg: block config dict; must contain ``"out_channels"``.
                It is deep-copied, so the caller's dict is not modified.
            dim_in: input channels; defaults to ``last_depth``.
            **kwargs: extra arguments. They override ``block_cfg``, which in
                turn overrides ``basic_args``.

        Returns:
            The constructed module.
        """
        if dim_in is None:
            dim_in = self.last_depth
        assert "out_channels" in block_cfg
        block_cfg = copy.deepcopy(block_cfg)
        out_channels = block_cfg.pop("out_channels")
        out_channels = self._get_divisible_width(
            out_channels * self.width_ratio
        )
        # Dicts appearing later override earlier ones:
        # basic_args < block_cfg < kwargs.
        new_kwargs = hp.get_merged_dict(self.basic_args, block_cfg, kwargs)
        ret = PRIMITIVES.get(block_op)(dim_in, out_channels, **new_kwargs)
        # Prefer the width reported by the module; fall back to the request.
        self.last_depth = getattr(ret, "out_channels", out_channels)
        return ret

    def _get_divisible_width(self, width):
        """Truncate ``width`` to int and round it with ``hp.get_divisible_by``.

        ``width_divisor`` is passed for both remaining arguments of that
        helper.
        """
        return hp.get_divisible_by(
            int(width), self.width_divisor, self.width_divisor
        )

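论文链接:FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search(https://arxiv.org/abs/1812.03443)
代码链接:https://github.com/facebookresearch/mobile-vision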

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值