03如何替换baseline的模型

替换~/clarity/recipes/cad_icassp_2024/baseline/enhance.py里用的demucs模型为htdemucs模型

换模型一般用这个函数

torch.hub.load 是 PyTorch 用于从模型库(如 GitHub 仓库)加载预训练模型的函数。其参数如下:

torch.hub.load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, **kwargs)

各参数详细说明如下:

  1. repo_or_dir (str): 指定模型库的 GitHub 仓库地址(格式为 repo_owner/repo_name[:ref])或本地目录路径。例如 pytorch/vision:v0.10.0
  2. model (str): 要加载的模型名称。例如 resnet50
  3. *args: 传递给模型构造函数的其他位置参数。
  4. source (str, optional): 模型库的来源。默认值为 'github',可以是 'local' 来表示本地目录。
  5. force_reload (bool, optional): 是否强制重新下载模型。默认值为 False
  6. verbose (bool, optional): 是否显示详细信息。默认值为 True
  7. skip_validation (bool, optional): 是否跳过对 repo_or_dir 参数的验证。默认值为 False
  8. kwargs: 传递给模型构造函数的其他关键字参数。

我遇到的实际情况:想添加一个htdemucs的模型

 # Loading pretrained source separation model
    if config.separator.model == "demucs":
        separation_model = HDEMUCS_HIGH_MUSDB.get_model()
        model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
        sources_order = separation_model.sources
        normalise = True
    elif config.separator.model == "openunmix":
        separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
        model_sample_rate = separation_model.sample_rate
        sources_order = ["vocals", "drums", "bass", "other"]
        normalise = False
    elif config.separator.model=="htdemucs":
        separation_model = torch.hub.load("/home/wujunyu/clarity/recipes/cad_icassp_2024/baseline/demucs/", "get_model","htdemucs",source='local')
        model_sample_rate = separation_model.samplerate
        sources_order = separation_model.sources
        normalise = True     
    else:
        raise ValueError(f"Separator model {config.separator.model} not supported.")

关于torch.hub.load的参数有不少需要注意的点:

首先要知道

就是这个函数是在/home/wujunyu/miniconda3/envs/clarity/lib/python3.8/site-packages/torch/hub.py定义的,是属于torch这个包的。

然后什么样的模型可以以这种方式加载呢?你去GitHub上看,需要有一个文件叫hubconf.py在这里插入图片描述

在这里插入图片描述

在torch/hub.py里面,torch.hub.load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, **kwargs)这个函数先检查看repo_or_dir下面有没有一个叫hubconf.py的文件,必须有(如果想换一个模型,然后这个模型的GitHub上没有这个hubconf.py,或许可以先下载到本地然后自己添加一个?),然后检查hubconf.py里的dependencies列表列出的依赖是不是都已经下载了(这边依赖的名字可能会有问题,比如dora写成dora-search,就是名字的问题,包其实没问题,需要手动修改一下),然后model这个参数必须是在hubconf.py被import了的,model其实是一个函数啦,然后就会调用model(*args, **kwargs)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值