09 在DTT的框架下修改模型

在DTT的框架下修改模型

我们首先要了解一下模型是怎么加载的:

在 PyTorch Lightning 中,LightningModule 是一个封装了 PyTorch 模型的类,它提供了一个简化的训练、验证和测试流程的接口。Hydra 是一个配置管理工具,它可以帮助你管理复杂的配置,并在运行时动态地组合它们。

当你在代码中看到:

model: LightningModule = hydra.utils.instantiate(cfg.model)

这行代码做了以下几件事情:

  1. 从配置中读取模型配置cfg.model 是一个配置对象(通常是 DictConfig 类型),它包含了模型的配置信息。这个配置信息可能包括模型的类名、初始化参数等。

configs\model\vocals.yaml文件:

_target_: src.dp_tdf.dp_tdf_net.DPTDFNet

target_name: 'vocals'
lr: 0.0001
optimizer: adamW
#####此处省略一大堆模型参数
  1. 实例化模型hydra.utils.instantiate 是一个函数,它根据提供的配置对象创建一个实例。具体来说,它查找配置对象中的 _target_ 键,这个键指定了要实例化的类的全名(包括模块路径)。

  2. 返回一个 LightningModule 实例:实例化完成后,函数返回一个 LightningModule 的实例,这个实例是根据配置文件中指定的参数创建的。

详细步骤

  1. 定义配置: 假设你有一个配置文件 config.yaml,其中定义了模型的配置:
yamlmodel:
_target_: my_module.MyModel
param1: value1
param2: value2
  1. 加载配置: 在你的代码中,你首先需要加载这个配置文件:
from omegaconf import OmegaConf
import hydra

cfg = hydra.compose("configs/config.yaml")

或者用作装饰器

@hydra.main(config_path="configs/", config_name="config.yaml")
  1. 实例化模型: 使用 Hydra 的 instantiate 函数来创建模型实例:
from pytorch_lightning import LightningModule
import hydra

model: LightningModule = hydra.utils.instantiate(cfg.model)

在这个例子中,cfg.model 包含了以下内容:

python{
'_target_': 'my_module.MyModel',
'param1': 'value1',
'param2': 'value2'
}

hydra.utils.instantiate 函数会:

  • 查找 _target_ 键,找到 my_module.MyModel
  • param1param2 作为参数传递给 MyModel 的构造函数。
  1. 使用模型: 实例化后,你可以像使用普通的 PyTorch Lightning 模型一样使用这个模型:
pythontrainer = pl.Trainer(...)
trainer.fit(model)

注意事项

  • 确保 _target_ 键的值是有效的 Python 类的全名(包括模块路径)。
  • 确保传递给 instantiate 函数的配置对象是 DictConfig 类型,这是 OmegaConf 库中用于处理配置的类型。
  • 如果你的模型类需要一些特定的初始化参数,确保这些参数在配置文件中正确定义。

通过这种方式,Hydra 和 PyTorch Lightning 可以无缝集成,使得模型的配置和实例化过程更加灵活和可维护。

修改模型

我想加上时域,所以:

  1. 修改src\dp_tdf\abstract.py里的
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
  mix_wave, target_wave = args[0] # (batch, c, 261120)

  # input 1
  stft_44k = self.stft(mix_wave) # (batch, c*2, 1044, 256)
  # forward
  t_est_stft = self(stft_44k) # (batch, c, 1044, 256)

  loss = self.comp_loss(t_est_stft, target_wave)

  self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)

  return {"loss": loss}

def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
mix_wave, target_wave = args[0] # (batch, c, 261120)

pred_wave = self(mix_wave)

loss = self.comp_loss(pred_wave, target_wave)

self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)

return {"loss": loss}

(修改后的模型的输入和输出均为waveform)(传入损失函数的参数也是waveform)

修改src\dp_tdf\abstract.py

 def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
      mix_chunk_batches, target = args[0]

      ######省略###########

      for batch in mix_chunk_batches:
          # input
          stft_44k = self.stft(batch)  # (batch, c*2, 1044, 256)
            pred_detail = self(stft_44k) # (batch, c, 1044, 256), irm
          pred_detail = self.istft(pred_detail)

          target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap])
      target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t)

      ########省略##############3333

    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
          mix_chunk_batches, target = args[0]

        ######省略###########

          for batch in mix_chunk_batches:
            # input
            pred_detail = self(batch)

              target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap])
        target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t)

        ########省略##############3333
  1. 修改src\dp_tdf\abstract.py里的损失函数
def comp_loss(self, pred_detail, target_wave):
     pred_detail = self.istft(pred_detail)

  comp_loss = F.l1_loss(pred_detail, target_wave)

  self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False)

   return comp_loss

def comp_loss(self, pred_wave, target_wave):

  comp_loss = F.l1_loss(pred_wave, target_wave)

  self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False)

   return comp_loss
  1. 增加src\dp_tdf\my_net.py文件,定义了MYNET这个模型

  2. 增加configs\model\vocals_test.yaml文件,相较于vocals.yaml文件修改了_target_src.dp_tdf.my_net.MYNET,并把g修改为64

  3. 修改了configs\experiment\vocals_dis.yaml文件,把

defaults:
- multigpu_default
- override /model: vocals.yaml

改成

defaults:
- multigpu_default
- override /model: vocals_test.yaml

并把g由32改成64

  1. src\dp_tdf\modules.py添加t_TFC和t_TFC_LSTM这两个类定义

  2. src\dp_tdf\bandsequence.py中第80行把group_num = input_dim_size // 16改成group_num = input_dim_size // 23

(因为input_dim_size必须可以被group_num整除,我这边模型修改之后因为把waveform和频谱在通道上拼在一起,所以通道数由原来的3*64变成了3*115,16就除不尽了)

同理,n_heads原来是2,现在除不尽了,要改成3(在configs\model\vocals_test.yaml里修改)

  1. 先看看DTTNet-Pytorch/configs/experiment/vocals_dis.yaml中的trainer部分配置

之前加的resume_from_checkpoint要删掉

  1. 执行命令,开始训练

python train.py experiment=vocals_dis datamodule=musdb_dev14 trainer=default

  1. 发现显存不够,于是把batch从4改成1,相应的,学习率要调小(我之前把batch从8改成4的时候就忘了把lr从0.0002调成0.0001,所以现在要从0.0002除以8,即把lr改成0.000025

评价自己的模型

  1. 先修改一下src\evaluation\separate.py的185行附近:

    with torch.no_grad():
            model.eval()
            for mixture_wav in mix_waves_batched:
                mix_spec = model.stft(mixture_wav.to(device)) 
                spec_hat = model(mix_spec)
                target_wav_hat = model.istft(spec_hat)
                target_wav_hat = target_wav_hat.cpu().detach().numpy()
                target_wav_hats.append(target_wav_hat) # (b, c, t)
    

    改成

    with torch.no_grad():
            model.eval()
            for mixture_wav in mix_waves_batched:
                mixture_wav=mixture_wav.to(device)
                target_wav_hat= model(mixture_wav)
                target_wav_hat = target_wav_hat.cpu().detach().numpy()
                target_wav_hats.append(target_wav_hat) # (b, c, t)
    

    (一定要x=x.to(device)哦,光光x.to(device)没用的)

  2. configs\evaluation.yaml把batch从4改成1?(反正我改成1了)

  3. 这个很重要:把configs\model\vocals_test.yaml里的bn_normBN改成syncBN

    其实在训练时就该改了,不过训练时的配置文件比较复杂,configs\experiment\vocals_dis.yaml把bn_norm设置为syncBN并覆盖了原来的设置BN。

    到了评分时,没用到这个配置文件,就需要改onfigs\model\vocals_test.yaml里的bn_norm了

    那为什么BN不行呢?

    batchnorm的源码开头检查了维度必须为4,syncbatchnorm源码就没这个检查

export ckpt_path="/home/wujunyu/DTTNet-Pytorch/check_points/vocals_vocals_g32_38/checkpoints/epoch=04-step=4050.ckpt"
####上面这个是作者提供的训练好的参数

python run_eval.py model=vocals_test logger.wandb.name=evaluate_new_model_1.0_with_too_high_lr
###注意这个xxxx应替换为一个实际的名称(帮助你识别和管理不同的训练或评估运行)
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值