[Speech Separation] SpeechBrain's SepFormer (TensorFlow 2.4.0)

This post describes how to port the PyTorch SepFormer model to TensorFlow 2.4.0: decomposing the model, saving and loading its parameters, testing, and training. It covers debugging the PyTorch model and dumping its parameters to a yaml file; building the TensorFlow model, loading the parameters, prediction, and evaluation; data handling with Dataset; and a TensorFlow implementation of the uPIT-SiSNR loss together with training callbacks.


Debugging the PyTorch Model

Before rewriting SepFormer, it helps to debug the original PyTorch model and inspect all of its layers and parameters. Debugging is based mainly on train.py, in either of two ways:

  1. Set breakpoints inside the __init__ functions of the relevant modules in speechbrain/lobes/models/dual_path.py, i.e. the Encoder, SBTransformerBlock, Dual_Path_Model, and Decoder used by SepformerWrapper. These breakpoints are hit while the model is being constructed.
class SepformerWrapper(nn.Module):
    """The wrapper for the sepformer model which combines the Encoder, Masknet and the decoder
    https://arxiv.org/abs/2010.13154

    Arguments
    ---------

    encoder_kernel_size: int,
        The kernel size used in the encoder
    encoder_in_nchannels: int,
        The number of channels of the input audio
    encoder_out_nchannels: int,
        The number of filters used in the encoder.
        Also, number of channels that would be inputted to the intra and inter blocks.
    masknet_chunksize: int,
        The chunk length that is to be processed by the intra blocks
    masknet_numlayers: int,
        The number of layers of combination of inter and intra blocks
    masknet_norm: str,
        The normalization type to be used in the masknet
        Should be one of 'ln' -- layernorm, 'gln' -- globallayernorm
                         'cln' -- cumulative layernorm, 'bn' -- batchnorm
                         -- see the select_norm function above for more details
    masknet_useextralinearlayer: bool,
        Whether or not to use a linear layer at the output of intra and inter blocks
    masknet_extraskipconnection: bool,
        This introduces extra skip connections around the intra block
    masknet_numspks: int,
        This determines the number of speakers to estimate
    intra_numlayers: int,
        This determines the number of layers in the intra block
    inter_numlayers: int,
        This determines the number of layers in the inter block
    intra_nhead: int,
        This determines the number of parallel attention heads in the intra block
    inter_nhead: int,
        This determines the number of parallel attention heads in the inter block
    intra_dffn: int,
        The number of dimensions in the positional feedforward model in the intra block
    inter_dffn: int,
        The number of dimensions in the positional feedforward model in the inter block
    intra_use_positional: bool,
        Whether or not to use positional encodings in the intra block
    inter_use_positional: bool,
        Whether or not to use positional encodings in the inter block
    intra_norm_before: bool
        Whether or not we use normalization before the transformations in the intra block
    inter_norm_before: bool
        Whether or not we use normalization before the transformations in the inter block

    Example
    -----
    >>> model = SepformerWrapper()
    >>> inp = torch.rand(1, 160)
    >>> result = model.forward(inp)
    >>> result.shape
    torch.Size([1, 160, 2])
    """

    def __init__(
        self,
        encoder_kernel_size=16,
        encoder_in_nchannels=1,
        encoder_out_nchannels=256,
        masknet_chunksize=250,
        masknet_numlayers=2,
        masknet_norm="ln",
        masknet_useextralinearlayer=False,
        masknet_extraskipconnection=True,
        masknet_numspks=2,
        intra_numlayers=8,
        inter_numlayers=8,
        intra_nhead=8,
        inter_nhead=8,
        intra_dffn=1024,
        inter_dffn=1024,
        intra_use_positional=True,
        inter_use_positional=True,
        intra_norm_before=True,
        inter_norm_before=True,
    ):

        super(SepformerWrapper, self).__init__()
        self.encoder = Encoder(
            kernel_size=encoder_kernel_size,
            out_channels=encoder_out_nchannels,
            in_channels=encoder_in_nchannels,
        )
        intra_model = SBTransformerBlock(
            num_layers=intra_numlayers,
            d_model=encoder_out_nchannels,
            nhead=intra_nhead,
            d_ffn=intra_dffn,
            use_positional_encoding=intra_use_positional,
            norm_before=intra_norm_before,
        )

        inter_model = SBTransformerBlock(
            num_layers=inter_numlayers,
            d_model=encoder_out_nchannels,
            nhead=inter_nhead,
            d_ffn=inter_dffn,
            use_positional_encoding=inter_use_positional,
            norm_before=inter_norm_before,
        )

        self.masknet = Dual_Path_Model(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_out_nchannels,
            intra_model=intra_model,
            inter_model=inter_model,
            num_layers=masknet_numlayers,
            norm=masknet_norm,
            K=masknet_chunksize,
            num_spks=masknet_numspks,
            skip_around_intra=masknet_extraskipconnection,
            linear_layer_after_inter_intra=masknet_useextralinearlayer,
        )
        self.decoder = Decoder(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_in_nchannels,
            kernel_size=encoder_kernel_size,
            stride=encoder_kernel_size // 2,
            bias=False,
        )
        self.num_spks = masknet_numspks

        # reinitialize the parameters
        for module in [self.encoder, self.masknet, self.decoder]:
            self.reset_layer_recursively(module)

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the network"""
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.modules():
            if layer != child_layer:
                self.reset_layer_recursively(child_layer)

    def forward(self, mix):

        mix_w = self.encoder(mix)
        est_mask = self.masknet(mix_w)
        mix_w = torch.stack([mix_w] * self.num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source
  2. Alternatively, set breakpoints in the forward methods of the modules above; these are hit when an instance is called. The snippet below (also embedded in train.py) shows this; a hook-based alternative is sketched after it:
  • feed a 1024-sample chunk of PCM data as input
  • output_encoder is the encoder output
  • output_masknet is the masknet output
  • output_decoder_0 and output_decoder_1 are the two separated speech streams
import torch
print(separator.modules)

x = torch.randn(1, 1024)
output_encoder = separator.modules['encoder'](x)  # (1, 256, 127): (batch, filters, frames)
output_masknet = separator.modules['masknet'](output_encoder)  # (2, 1, 256, 127): one mask per speaker
# the decoder normally receives mix_w * mask; feeding the raw masks here is only a shape check
output_decoder_0 = separator.modules['decoder'](output_masknet[0])
output_decoder_1 = separator.modules['decoder'](output_masknet[1])
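
As an alternative to breakpoints, forward hooks can print every submodule's output shape in a single pass. A minimal sketch (the hook helper below is not part of the original recipe):

import torch

def shape_hook(name):
    # print the module's name and its output shape whenever it runs
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            print(name, tuple(output.shape))
    return hook

for name, module in separator.modules.named_modules():
    if name:  # skip the root ModuleDict itself
        module.register_forward_hook(shape_hook(name))

with torch.no_grad():
    mix_w = separator.modules['encoder'](torch.randn(1, 1024))
    est_mask = separator.modules['masknet'](mix_w)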

The full __main__ section of train.py is:

if __name__ == "__main__":

    # Load hyperparameters file with command-line overrides
    # hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) # todo
    # rewrite as:
    hparams_file = '/home/zhaodeng/SoundPlus/speechbrain_4csdn/speechbrain/recipes/WSJ0Mix/separation/train/sepformer.yaml'
    hparams_file, run_opts, overrides = sb.parse_arguments([hparams_file])

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    # Logger info
    logger = logging.getLogger(__name__)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Check if wsj0_tr is set with dynamic mixing
    if hparams["dynamic_mixing"] and not os.path.exists(hparams["wsj0_tr"]):
        print(
            "Please, specify a valid wsj0_tr folder when using dynamic mixing"
        )
        sys.exit(1)

    # Data preparation # todo: replace
    # from recipes.WSJ0Mix.prepare_data import prepare_wsjmix  # noqa
    #
    # run_on_main(
    #    prepare_wsjmix,
    #    kwargs={
    #        "datapath": hparams["data_folder"],
    #        "savepath": hparams["save_folder"],
    #        "n_spks": hparams["num_spks"],
    #        "skip_prep": hparams["skip_prep"],
    #    },
    # )
    from prepare_data import prepare_wsjmix  # noqa
    run_on_main(
        prepare_wsjmix,
        kwargs={
            "datapath": hparams["data_folder"],
            "savepath": hparams["save_folder"],
            "n_spks": hparams["num_spks"],
            "skip_prep": hparams["skip_prep"],
        },
    )

    # Create dataset objects
    if hparams["dynamic_mixing"]:

        if hparams["num_spks"] == 2:
            from dynamic_mixing import dynamic_mix_data_prep  # noqa

            train_data = dynamic_mix_data_prep(hparams)
        elif hparams["num_spks"] == 3:
            from dynamic_mixing import dynamic_mix_data_prep_3mix  # noqa

            train_data = dynamic_mix_data_prep_3mix(hparams)
        else:
            raise ValueError(
                "The specified number of speakers is not supported."
            )
        _, valid_data, test_data = dataio_prep(hparams)
    else:
        train_data, valid_data, test_data = dataio_prep(hparams)

    # Brain class initialization
    separator = Separation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    ## debug: inspect per-layer outputs (same snippet as above)
    import torch
    print(separator.modules)

    x = torch.randn(1, 1024)
    output_encoder = separator.modules['encoder'](x)  # (1, 256, 127)
    output_masknet = separator.modules['masknet'](output_encoder)
    output_decoder_0 = separator.modules['decoder'](output_masknet[0])
    output_decoder_1 = separator.modules['decoder'](output_masknet[1])

    # re-initialize the parameters
    for module in separator.modules.values():
        separator.reset_layer_recursively(module)

    if not hparams["test_only"]:
        # Training
        separator.fit(
            separator.hparams.epoch_counter,
            train_data,
            valid_data,
            train_loader_kwargs=hparams["dataloader_opts"],
            valid_loader_kwargs=hparams["dataloader_opts"],
        )

    # Eval
    separator.evaluate(test_data, min_key="si-snr")
    separator.save_results(test_data)

Decomposing the PyTorch Model (matching the TensorFlow model layer by layer and obtaining each layer's output)

Having debugged the PyTorch SepFormer, the model can be decomposed layer by layer (indeed step by step), emitting the output of each stage, as follows.

Note: a fixed audio file is fed in here, and the model parameters are loaded from a trained checkpoint.

def print_module_output():
    import torch
    import torchaudio
    from speechbrain.pretrained import SepformerSeparation as separator

    model = separator.from_hparams(
        source="/media/me/nvme2n1/SoundPlus/SpeechBrain/speechbrain-develop/recipes/WSJ0Mix/separation/zd/load/results-v3/sepformer/1234/save/CKPT+2021-06-10+16-52-44+00",
        savedir='./sepformer_train_3990_v3')

    output_write = open('layer_output_th.txt', 'w+')

    # load a fixed test wav (8 kHz, 16-bit)
    wav_file = '/media/me/nvme2n1/SoundPlus/SpeechBrain/sepformer_tf2/snr0_8k_16b.wav'
    batch, fs_file = torchaudio.load(wav_file)
    batch = batch[:, :4 * 8000]  # keep the first 4 s
    batch = torch.nn.functional.pad(batch, pad=[0, 32000 - batch.shape[1]])  # optional: pad up to 32000 samples

    # encoder
    mix_w_our = model.modules.encoder(batch)

    # masknet
    est_mask = model.modules.masknet.norm(mix_w_our)
    est_mask = model.modules.masknet.conv1d(est_mask)
    # segmentation: split (B, N, L) into overlapping chunks of length K -> (B, N, K, S)
    segment, gap = model.modules.masknet._Segmentation(est_mask, model.modules.masknet.K)

    # dual_path
    est_mask = segment
    for dual_i in range(2):  # 2 == masknet_numlayers
        est_mask0 = est_mask
        B, N, K, S = est_mask.shape
        est_mask = est_mask.permute(0, 3, 2, 1).contiguous().view(B * S, K, N)

        ## intra_mdl:pos_enc
        # positional embedding
        pos_enc = model.modules.masknet.dual_mdl[dual_i].intra_mdl.pos_enc(est_mask)
        est_mask = pos_enc + est_mask

        # intra_mdl.mdl: layers + norm
        src = est_mask
        output_our_list = []
        for layer_i in range(8):
            src1 = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].norm1(src)
            output, self_attns = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].self_att(
                src1,
                src1,
                src1,
                attn_mask=None,
                key_padding_mask=None,
            )
            src = src + model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].dropout1(output)
            src1 = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].norm2(src)
            output = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].pos_ffn(src1)
            src = src + model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].dropout2(output)
            output_our_list.append(src)
        est_mask = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.norm(src)
        ## intra_mdl end
        
        # intra_norm
        est_mask = est_mask.view(B, S, K, N)
        est_mask = est_mask.permute(0, 3, 2, 1).contiguous()
        est_mask = model.modules.masknet.dual_mdl[dual_i].intra_norm(est_mask) # ok
        intra = est_mask + est_mask0

        ## inter_mdl
        # pos_enc
        inter = intra.permute(0, 2, 3, 1).contiguous().view(B * K, S, N)
        pos_enc = model.modules.masknet.dual_mdl[dual_i].inter_mdl.pos_enc(inter)
        est_mask = pos_enc + inter

        # inter_mdl: layers + norm
        src = est_mask
        for layer_i in range(8):
            src1 = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].norm1(src)
            output, self_attns = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].self_att(
                src1,
                src1,
                src1,
                attn_mask=None,
                key_padding_mask=None,
            )
            src = src + model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].dropout1(output)
            src1 = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].norm2(src)
            output = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].pos_ffn(src1)
            src = src + model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].dropout2(output)
        est_mask = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.norm(src)
        
        # inter_norm
        inter = est_mask.view(B, K, S, N)
        inter = inter.permute(0, 3, 1, 2).contiguous()
        inter = model.modules.masknet.dual_mdl[dual_i].inter_norm(inter)
        est_mask = inter + intra
    
    # dual_mdl: equivalently, run the whole dual-path stack in one loop
    # x = segment
    # for i in range(model.modules.masknet.num_layers):
    #     x = model.modules.masknet.dual_mdl[i](x)

    # prelu
    est_mask = model.modules.masknet.prelu(est_mask)
    # conv2d
    est_mask = model.modules.masknet.conv2d(est_mask)
    # overlap-add: merge the chunks back to (B * num_spks, N, L)
    B, _, K, S = est_mask.shape
    est_mask = est_mask.view(B * 2, -1, K, S)  # 2 == num_spks
    est_mask = model.modules.masknet._over_add(est_mask, gap)
    
    # output * output_gate
    output_o = model.modules.masknet.output(est_mask)
    output_gate = model.modules.masknet.output_gate(est_mask)
    est_mask = output_o * output_gate

    # conv1d
    est_mask = model.modules.masknet.end_conv1x1(est_mask)
    _, N, L = est_mask.shape
    est_mask = est_mask.view(B, model.modules.masknet.num_spks, N, L)
    est_mask = model.modules.masknet.activation(est_mask)
    est_mask_our = est_mask.transpose(0, 1)
    
    # or run the entire masknet in one call:
    # est_mask_orig = model.modules.masknet(mix_w_our)
    
    # decoder
    mix_w_our = torch.stack([mix_w_our] * model.hparams.num_spks)
    sep_h_our = mix_w_our * est_mask_our
    # Decoding
    est_source_our = torch.cat(
        [
            model.modules.decoder(sep_h_our[i]).unsqueeze(-1)
            for i in range(model.hparams.num_spks)
        ],
        dim=-1,
    )
    # dump the model output to a text file (for later comparison with the TF port)
    layer_output = est_source_our.detach().numpy()
    for i in range(len(layer_output[0])):
        for j in range(len(layer_output[0][i])):
            output_write.write(str(layer_output[0][i][j]) + ' ')
        output_write.write('\n')
    output_write.write('\n')
    output_write.write('\n')

    ### IMPORTANT ###
    # peak-normalize each separated stream before saving
    est_source = est_source_our / est_source_our.max(dim=1, keepdim=True)[0]
    # save to wav file
    torchaudio.save("snr0_1.wav", est_source[:, :, 0].detach().cpu(), 8000)
    torchaudio.save("snr0_2.wav", est_source[:, :, 1].detach().cpu(), 8000)

    return mix_w_our, est_mask, est_source
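
With the per-layer PyTorch outputs in hand, each rebuilt TensorFlow layer can be verified numerically before moving on. A minimal sketch, where tf_encoder (a TF mirror of model.modules.encoder) and batch_np (the same test wav as a numpy array) are hypothetical names, not part of the original code:

import numpy as np

mix_w_our, est_mask, est_source = print_module_output()

ref = mix_w_our.detach().cpu().numpy()   # PyTorch encoder output: (B, N, L)
out = tf_encoder(batch_np).numpy()       # hypothetical TF encoder; Conv1D is channels-last: (B, L, N)
out = np.transpose(out, (0, 2, 1))       # align the channel layout before comparing
print('max abs diff:', np.abs(ref - out).max())
print('allclose:', np.allclose(ref, out, atol=1e-4))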

Saving the Trained PyTorch Parameters (to a yaml file loadable by the TensorFlow model)

Here the trained PyTorch model is loaded, each layer's parameters are saved into a dictionary under TF-style names, and the dictionary is dumped to a yaml file.

The TensorFlow model can then load this yaml and assign each parameter by name (a loading sketch follows the dump function below).

def loadmodel_and_dumpyaml():
    from speechbrain.pretrained import SepformerSeparation as separator
    import numpy as np
    import yaml
    try:
        from yaml import CLoader as Loader, CDumper as Dumper
    except ImportError:
        from yaml import Loader, Dumper

    model = separator.from_hparams(
        source="/media/me/nvme2n1/SoundPlus/SpeechBrain/speechbrain-develop/recipes/WSJ0Mix/separation/zd/load/results-v3/sepformer/1234/save/CKPT+2021-06-10+16-52-44+00",
        savedir='./sepformer_train_3990_v3')
    # summary(model, (1, 32000, 1))

    yaml_key_value = {}

    # encoder: PyTorch Conv1d weight is (out_ch, in_ch, kernel) = (256, 1, 16),
    # while TF Conv1D expects (kernel, in_ch, out_ch), hence the transpose
    scope = 'encoder'
    key = scope + '/conv1d/kernel:0'
    value = model.modules['encoder'].conv1d.weight.detach().cpu().numpy()
    yaml_key_value[key] = value.transpose(2, 1, 0)  # (256, 1, 16) -> (16, 1, 256)

    # masknet
    scope = 'masknet'
    masknet = model.modules['masknet']
    # ... the masknet and decoder weights are dumped analogously (norm, conv1d,
    # dual_mdl transformer layers, prelu, conv2d, output, output_gate, end_conv1x1),
    # each transposed to the TF layout and keyed by its TF variable name; finally
    # yaml.dump writes yaml_key_value to file (arrays may need .tolist() first)
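On the TensorFlow side, the saved yaml can then be read back and each variable assigned by name. Below is a minimal sketch of that loading step; tf_model (the rebuilt TF 2.4 model, whose variable names are assumed to match the dumped keys) and the file name sepformer_params.yaml are illustrative, and the arrays are assumed to have been dumped as plain lists:

import numpy as np
import yaml

# illustrative loading sketch; `tf_model` and the yaml file name are assumptions
with open('sepformer_params.yaml') as fin:
    params = yaml.safe_load(fin)

for var in tf_model.variables:
    if var.name in params:  # e.g. 'encoder/conv1d/kernel:0'
        var.assign(np.asarray(params[var.name], dtype='float32'))
    else:
        print('no dumped value for', var.name)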