替换~/clarity/recipes/cad_icassp_2024/baseline/enhance.py里用的demucs模型为htdemucs模型
换模型一般用这个函数
torch.hub.load
是 PyTorch 用于从模型库(如 GitHub 仓库)加载预训练模型的函数。其参数如下:
torch.hub.load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, **kwargs)
各参数详细说明如下:
- repo_or_dir (str): 指定模型库的 GitHub 仓库地址(格式为
repo_owner/repo_name[:ref]
)或本地目录路径。例如pytorch/vision:v0.10.0
。 - model (str): 要加载的模型名称。例如
resnet50
。 - *args: 传递给模型构造函数的其他位置参数。
- source (str, optional): 模型库的来源。默认值为
'github'
,可以是'local'
来表示本地目录。 - force_reload (bool, optional): 是否强制重新下载模型。默认值为
False
。 - verbose (bool, optional): 是否显示详细信息。默认值为
True
。 - skip_validation (bool, optional): 是否跳过对 repo_or_dir 参数的验证。默认值为
False
。 - kwargs: 传递给模型构造函数的其他关键字参数。
我遇到的实际情况:想添加一个htdemucs的模型
# Loading pretrained source separation model
if config.separator.model == "demucs":
separation_model = HDEMUCS_HIGH_MUSDB.get_model()
model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
sources_order = separation_model.sources
normalise = True
elif config.separator.model == "openunmix":
separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
model_sample_rate = separation_model.sample_rate
sources_order = ["vocals", "drums", "bass", "other"]
normalise = False
elif config.separator.model=="htdemucs":
separation_model = torch.hub.load("/home/wujunyu/clarity/recipes/cad_icassp_2024/baseline/demucs/", "get_model","htdemucs",source='local')
model_sample_rate = separation_model.samplerate
sources_order = separation_model.sources
normalise = True
else:
raise ValueError(f"Separator model {config.separator.model} not supported.")
关于torch.hub.load的参数有不少需要注意的点:
首先要知道
就是这个函数是在/home/wujunyu/miniconda3/envs/clarity/lib/python3.8/site-packages/torch/hub.py
定义的,是属于torch这个包的。
然后什么样的模型可以以这种方式加载呢?你去GitHub上看,需要有一个文件叫hubconf.py
在torch/hub.py里面,torch.hub.load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, **kwargs)
这个函数先检查看repo_or_dir下面有没有一个叫hubconf.py的文件,必须有(如果想换一个模型,然后这个模型的GitHub上没有这个hubconf.py,或许可以先下载到本地然后自己添加一个?),然后检查hubconf.py里的dependencies列表列出的依赖是不是都已经下载了(这边依赖的名字可能会有问题,比如dora写成dora-search,就是名字的问题,包其实没问题,需要手动修改一下),然后model这个参数必须是在hubconf.py被import了的,model其实是一个函数啦,然后就会调用model(*args, **kwargs)。