在DTT的框架下修改模型
我们首先要了解一下模型是怎么加载的:
在 PyTorch Lightning 中,
LightningModule
是一个封装了 PyTorch 模型的类,它提供了一个简化的训练、验证和测试流程的接口。Hydra 是一个配置管理工具,它可以帮助你管理复杂的配置,并在运行时动态地组合它们。当你在代码中看到:
model: LightningModule = hydra.utils.instantiate(cfg.model)
这行代码做了以下几件事情:
- 从配置中读取模型配置:
cfg.model
是一个配置对象(通常是DictConfig
类型),它包含了模型的配置信息。这个配置信息可能包括模型的类名、初始化参数等。
configs\model\vocals.yaml
文件:_target_: src.dp_tdf.dp_tdf_net.DPTDFNet target_name: 'vocals' lr: 0.0001 optimizer: adamW #####此处省略一大堆模型参数
实例化模型:
hydra.utils.instantiate
是一个函数,它根据提供的配置对象创建一个实例。具体来说,它查找配置对象中的_target_
键,这个键指定了要实例化的类的全名(包括模块路径)。返回一个
LightningModule
实例:实例化完成后,函数返回一个LightningModule
的实例,这个实例是根据配置文件中指定的参数创建的。详细步骤
- 定义配置: 假设你有一个配置文件
config.yaml
,其中定义了模型的配置:yamlmodel: _target_: my_module.MyModel param1: value1 param2: value2
- 加载配置: 在你的代码中,你首先需要加载这个配置文件:
from omegaconf import OmegaConf import hydra cfg = hydra.compose("configs/config.yaml")
或者用作装饰器
@hydra.main(config_path="configs/", config_name="config.yaml")
- 实例化模型: 使用 Hydra 的
instantiate
函数来创建模型实例:from pytorch_lightning import LightningModule import hydra model: LightningModule = hydra.utils.instantiate(cfg.model)
在这个例子中,
cfg.model
包含了以下内容:python{ '_target_': 'my_module.MyModel', 'param1': 'value1', 'param2': 'value2' }
hydra.utils.instantiate
函数会:
- 查找
_target_
键,找到my_module.MyModel
。- 将
param1
和param2
作为参数传递给MyModel
的构造函数。
- 使用模型: 实例化后,你可以像使用普通的 PyTorch Lightning 模型一样使用这个模型:
pythontrainer = pl.Trainer(...) trainer.fit(model)
注意事项
- 确保
_target_
键的值是有效的 Python 类的全名(包括模块路径)。- 确保传递给
instantiate
函数的配置对象是DictConfig
类型,这是 OmegaConf 库中用于处理配置的类型。- 如果你的模型类需要一些特定的初始化参数,确保这些参数在配置文件中正确定义。
通过这种方式,Hydra 和 PyTorch Lightning 可以无缝集成,使得模型的配置和实例化过程更加灵活和可维护。
修改模型
我想加上时域,所以:
- 修改
src\dp_tdf\abstract.py
里的def training_step(self, *args, **kwargs) -> STEP_OUTPUT: mix_wave, target_wave = args[0] # (batch, c, 261120) # input 1 stft_44k = self.stft(mix_wave) # (batch, c*2, 1044, 256) # forward t_est_stft = self(stft_44k) # (batch, c, 1044, 256) loss = self.comp_loss(t_est_stft, target_wave) self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True) return {"loss": loss}
为
def training_step(self, *args, **kwargs) -> STEP_OUTPUT: mix_wave, target_wave = args[0] # (batch, c, 261120) pred_wave = self(mix_wave) loss = self.comp_loss(pred_wave, target_wave) self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True) return {"loss": loss}
(修改后的模型的输入和输出均为waveform)(传入损失函数的参数也是waveform)
修改
src\dp_tdf\abstract.py
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: mix_chunk_batches, target = args[0] ######省略########### for batch in mix_chunk_batches: # input stft_44k = self.stft(batch) # (batch, c*2, 1044, 256) pred_detail = self(stft_44k) # (batch, c, 1044, 256), irm pred_detail = self.istft(pred_detail) target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap]) target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t) ########省略##############3333
为
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: mix_chunk_batches, target = args[0] ######省略########### for batch in mix_chunk_batches: # input pred_detail = self(batch) target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap]) target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t) ########省略##############3333
- 修改
src\dp_tdf\abstract.py
里的损失函数def comp_loss(self, pred_detail, target_wave): pred_detail = self.istft(pred_detail) comp_loss = F.l1_loss(pred_detail, target_wave) self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False) return comp_loss
为
def comp_loss(self, pred_wave, target_wave): comp_loss = F.l1_loss(pred_wave, target_wave) self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False) return comp_loss
增加
src\dp_tdf\my_net.py
文件,定义了MYNET这个模型增加
configs\model\vocals_test.yaml
文件,相较于vocals.yaml文件修改了_target_
为src.dp_tdf.my_net.MYNET
,并把g修改为64修改了
configs\experiment\vocals_dis.yaml
文件,把defaults: - multigpu_default - override /model: vocals.yaml
改成
defaults: - multigpu_default - override /model: vocals_test.yaml
并把
g
由32改成64
在
src\dp_tdf\modules.py
添加t_TFC和t_TFC_LSTM这两个类定义在
src\dp_tdf\bandsequence.py
中第80行把group_num = input_dim_size // 16
改成group_num = input_dim_size // 23
(因为input_dim_size必须可以被group_num整除,我这边模型修改之后因为把waveform和频谱在通道上拼在一起,所以通道数由原来的
3*64
变成了3*115
,16就除不尽了)同理,n_heads原来是2,现在除不尽了,要改成3(在
configs\model\vocals_test.yaml
里修改)
- 先看看
DTTNet-Pytorch/configs/experiment/vocals_dis.yaml
中的trainer部分配置之前加的resume_from_checkpoint要删掉
- 执行命令,开始训练
python train.py experiment=vocals_dis datamodule=musdb_dev14 trainer=default
- 发现显存不够,于是把batch从4改成1,相应的,学习率要调小(我之前把batch从8改成4的时候就忘了把lr从0.0002调成0.0001,所以现在要从0.0002除以8,即把lr改成0.000025
评价自己的模型
-
先修改一下
src\evaluation\separate.py
的185行附近:with torch.no_grad(): model.eval() for mixture_wav in mix_waves_batched: mix_spec = model.stft(mixture_wav.to(device)) spec_hat = model(mix_spec) target_wav_hat = model.istft(spec_hat) target_wav_hat = target_wav_hat.cpu().detach().numpy() target_wav_hats.append(target_wav_hat) # (b, c, t)
改成
with torch.no_grad(): model.eval() for mixture_wav in mix_waves_batched: mixture_wav=mixture_wav.to(device) target_wav_hat= model(mixture_wav) target_wav_hat = target_wav_hat.cpu().detach().numpy() target_wav_hats.append(target_wav_hat) # (b, c, t)
(一定要
x=x.to(device)
哦,光光x.to(device)
没用的) -
在
configs\evaluation.yaml
把batch从4改成1?(反正我改成1了) -
这个很重要:把
configs\model\vocals_test.yaml
里的bn_norm
从BN
改成syncBN
其实在训练时就该改了,不过训练时的配置文件比较复杂,
configs\experiment\vocals_dis.yaml
把bn_norm设置为syncBN并覆盖了原来的设置BN。到了评分时,没用到这个配置文件,就需要改
onfigs\model\vocals_test.yaml
里的bn_norm了那为什么BN不行呢?
batchnorm的源码开头检查了维度必须为4,syncbatchnorm源码就没这个检查
export ckpt_path="/home/wujunyu/DTTNet-Pytorch/check_points/vocals_vocals_g32_38/checkpoints/epoch=04-step=4050.ckpt"
####上面这个是作者提供的训练好的参数
python run_eval.py model=vocals_test logger.wandb.name=evaluate_new_model_1.0_with_too_high_lr
###注意这个xxxx应替换为一个实际的名称(帮助你识别和管理不同的训练或评估运行)