首先去git上download整个zip包并解压缩。
JierunChen/FasterNet: [CVPR 2023] Code for PConv and FasterNet (github.com)
然后注意查看下面的readme,里面有给出各种类型的预训练模型,可以下载你所需要的类型。
这里我需要使用fasternet_s,就下载其对应的pth文件即可。
然后把该文件放到XX/user/.cache/torch/hub/checkpoints目录下,这里是由于fasternet是timm库里的模型,想用timm库来创建该模型,因此放到这里。
接下来打开解压缩后的项目文件,可以看到在cfg目录下有不同类型模型对应的参数详情,这里fasternet_s对应的是fasternet_s.yaml文件,打开后可以看到里面写的模型参数配置。
接下来在models文件夹下打开fasternet.py文件,可以看到该文件中只有关于FasterNet的定义,但是没有关于FasterNet_S模型的调用函数,因此我们这里需要新写一个FasterNet_S函数来方便后续的模型创建。
在编写该函数时,参数的顺序和样式都直接参照FasterNet的init函数,具体的参数值如果在前面yaml文件中有,就直接写文件中的值;如果文件中没提到该参数的值,就直接参照FasterNet的init函数中的值写即可。
from timm.models.registry import register_model
@register_model
def fasternet_s(pretrained=False,**kwargs):
model=FasterNet(
in_chans=3,
num_classes=1000,
embed_dim=128,
depths=(1,2,13,2),
mlp_ratio=2.,
n_div=4,
patch_size=4,
patch_stride=4,
patch_size2=2,
patch_stride2=2,
patch_norm=True,
feature_dim=1280,
drop_path_rate=0.1,
layer_scale_init_value=0,
norm_layer='BN',
act_layer='RELU',
fork_feat=False,
init_cfg=None,
pretrained=pretrained,
pconv_fw_type='split_cat',
**kwargs
)
if pretrained:
model_key = 'fasternet_s'
url = model_urls[model_key]
import timm
model=timm.models.create_model('fasternet_s')
return model
编写好该函数后,即可通过调用函数来创建fasternet_s模型。
from models.fasternet import fasternet_s
model_ft = fasternet_s(pretrained=True)
model_ft.to(DEVICE)