(深度学习入门)如何根据官方代码创建模型(以fasternet为例)

首先去git上download整个zip包并解压缩。

JierunChen/FasterNet: [CVPR 2023] Code for PConv and FasterNet (github.com)

然后注意查看下面的readme,里面有给出各种类型的预训练模型,可以下载你所需要的类型。

这里我需要使用fasternet_s,就下载其对应的pth文件即可。

然后把该文件放到XX/user/.cache/torch/hub/checkpoints目录下,这里是由于fasternet是timm库里的模型,想用timm库来创建该模型,因此放到这里。

接下来打开解压缩后的项目文件,可以看到在cfg目录下有不同类型模型对应的参数详情,这里fasternet_s对应的是fasternet_s.yaml文件,打开后可以看到里面写的模型参数配置。

接下来在models文件夹下打开fasternet.py文件,可以看到该文件中只有关于FasterNet的定义,但是没有关于FasterNet_S模型的调用函数,因此我们这里需要新写一个FasterNet_S函数来方便后续的模型创建。

在编写该函数时,参数的顺序和样式都直接参照FasterNet的init函数,具体的参数值如果在前面yaml文件中有,就直接写文件中的值;如果文件中没提到该参数的值,就直接参照FasterNet的init函数中的值写即可。

from timm.models.registry import register_model
@register_model
def fasternet_s(pretrained=False,**kwargs):
    model=FasterNet(
        in_chans=3,
        num_classes=1000,
        embed_dim=128,
        depths=(1,2,13,2),
        mlp_ratio=2.,
        n_div=4,
        patch_size=4,
        patch_stride=4,
        patch_size2=2,
        patch_stride2=2,
        patch_norm=True,
        feature_dim=1280,
        drop_path_rate=0.1,
        layer_scale_init_value=0,
        norm_layer='BN',
        act_layer='RELU',
        fork_feat=False,
        init_cfg=None,
        pretrained=pretrained,
        pconv_fw_type='split_cat',
        **kwargs
    )
    if pretrained:
        model_key = 'fasternet_s'
        url = model_urls[model_key]
        import timm
        model=timm.models.create_model('fasternet_s')
    return model

编写好该函数后,即可通过调用函数来创建fasternet_s模型。

from models.fasternet import fasternet_s
model_ft = fasternet_s(pretrained=True)
model_ft.to(DEVICE)

  • 9
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值