Table of Contents
- Debugging the PyTorch model
- Decomposing the PyTorch model (to map it layer by layer onto the TensorFlow model and obtain each layer's output)
- Saving the PyTorch model's trained parameters (as a yaml file that the TensorFlow model can load)
- The TensorFlow 2 SepFormer model
- Building the TensorFlow model and loading its parameters
- Testing the TensorFlow model (model.predict and model.evaluate)
- Getting the output of every layer / a specific layer of the TensorFlow model (to compare against the corresponding PyTorch layer outputs)
- The TensorFlow model's dataset (tf.data.Dataset)
- The TensorFlow model's loss function (uPIT-SiSNR)
- Training the TensorFlow model
Debugging the PyTorch model
Before rewriting SepFormer, it is worth debugging the original PyTorch model to inspect all of its layers and parameters. Debugging is driven mainly through train.py, in one of two ways:
- Set a breakpoint inside the `__init__` function of the relevant modules in speechbrain.lobes.dual_path.py, i.e. the Encoder, SBTransformerBlock, Dual_Path_Model and Decoder combined by SepformerWrapper. Such a breakpoint is hit while the model is being constructed (a minimal sketch follows the class code below).
class SepformerWrapper(nn.Module):
    """The wrapper for the sepformer model which combines the Encoder, Masknet and the decoder
    https://arxiv.org/abs/2010.13154

    Arguments
    ---------
    encoder_kernel_size: int,
        The kernel size used in the encoder
    encoder_in_nchannels: int,
        The number of channels of the input audio
    encoder_out_nchannels: int,
        The number of filters used in the encoder.
        Also, number of channels that would be inputted to the intra and inter blocks.
    masknet_chunksize: int,
        The chunk length that is to be processed by the intra blocks
    masknet_numlayers: int,
        The number of layers of combination of inter and intra blocks
    masknet_norm: str,
        The normalization type to be used in the masknet
        Should be one of 'ln' -- layernorm, 'gln' -- globallayernorm
        'cln' -- cumulative layernorm, 'bn' -- batchnorm
        -- see the select_norm function above for more details
    masknet_useextralinearlayer: bool,
        Whether or not to use a linear layer at the output of intra and inter blocks
    masknet_extraskipconnection: bool,
        This introduces extra skip connections around the intra block
    masknet_numspks: int,
        This determines the number of speakers to estimate
    intra_numlayers: int,
        This determines the number of layers in the intra block
    inter_numlayers: int,
        This determines the number of layers in the inter block
    intra_nhead: int,
        This determines the number of parallel attention heads in the intra block
    inter_nhead: int,
        This determines the number of parallel attention heads in the inter block
    intra_dffn: int,
        The number of dimensions in the positional feedforward model in the intra block
    inter_dffn: int,
        The number of dimensions in the positional feedforward model in the inter block
    intra_use_positional: bool,
        Whether or not to use positional encodings in the intra block
    inter_use_positional: bool,
        Whether or not to use positional encodings in the inter block
    intra_norm_before: bool
        Whether or not we use normalization before the transformations in the intra block
    inter_norm_before: bool
        Whether or not we use normalization before the transformations in the inter block

    Example
    -----
    >>> model = SepformerWrapper()
    >>> inp = torch.rand(1, 160)
    >>> result = model.forward(inp)
    >>> result.shape
    torch.Size([1, 160, 2])
    """

    def __init__(
        self,
        encoder_kernel_size=16,
        encoder_in_nchannels=1,
        encoder_out_nchannels=256,
        masknet_chunksize=250,
        masknet_numlayers=2,
        masknet_norm="ln",
        masknet_useextralinearlayer=False,
        masknet_extraskipconnection=True,
        masknet_numspks=2,
        intra_numlayers=8,
        inter_numlayers=8,
        intra_nhead=8,
        inter_nhead=8,
        intra_dffn=1024,
        inter_dffn=1024,
        intra_use_positional=True,
        inter_use_positional=True,
        intra_norm_before=True,
        inter_norm_before=True,
    ):
        super(SepformerWrapper, self).__init__()
        self.encoder = Encoder(
            kernel_size=encoder_kernel_size,
            out_channels=encoder_out_nchannels,
            in_channels=encoder_in_nchannels,
        )
        intra_model = SBTransformerBlock(
            num_layers=intra_numlayers,
            d_model=encoder_out_nchannels,
            nhead=intra_nhead,
            d_ffn=intra_dffn,
            use_positional_encoding=intra_use_positional,
            norm_before=intra_norm_before,
        )
        inter_model = SBTransformerBlock(
            num_layers=inter_numlayers,
            d_model=encoder_out_nchannels,
            nhead=inter_nhead,
            d_ffn=inter_dffn,
            use_positional_encoding=inter_use_positional,
            norm_before=inter_norm_before,
        )
        self.masknet = Dual_Path_Model(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_out_nchannels,
            intra_model=intra_model,
            inter_model=inter_model,
            num_layers=masknet_numlayers,
            norm=masknet_norm,
            K=masknet_chunksize,
            num_spks=masknet_numspks,
            skip_around_intra=masknet_extraskipconnection,
            linear_layer_after_inter_intra=masknet_useextralinearlayer,
        )
        self.decoder = Decoder(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_in_nchannels,
            kernel_size=encoder_kernel_size,
            stride=encoder_kernel_size // 2,
            bias=False,
        )
        self.num_spks = masknet_numspks

        # reinitialize the parameters
        for module in [self.encoder, self.masknet, self.decoder]:
            self.reset_layer_recursively(module)

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the network"""
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.modules():
            if layer != child_layer:
                self.reset_layer_recursively(child_layer)

    def forward(self, mix):
        mix_w = self.encoder(mix)
        est_mask = self.masknet(mix_w)
        mix_w = torch.stack([mix_w] * self.num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]
        return est_source
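As a concrete example (a minimal sketch), the breakpoint can be a single line pasted at the top of SepformerWrapper.__init__, or of any submodule's __init__, before launching train.py:

```python
# pause here while the model is being constructed
import pdb; pdb.set_trace()  # on Python 3.7+, the builtin breakpoint() also works
```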
- Alternatively, set a breakpoint in the `forward` method of the modules above; it is hit when the module instance is called. The corresponding snippet in train.py is shown below (a forward-hook variant follows it):
- the input is a segment of 1024 PCM samples
- output_encoder is the output of the encoder layer
- output_masknet is the output of the masknet layer
- output_decoder_0 and output_decoder_1 are the two separated speech signals
import torch
print(separator.modules)
x = torch.randn(1, 1024)
output_encoder = separator.modules['encoder'](x) # (1, 256, 127)
output_masknet = separator.modules['masknet'](output_encoder)
output_decoder_0 = separator.modules['decoder'](output_masknet[0])
output_decoder_1 = separator.modules['decoder'](output_masknet[1])
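As an alternative (a sketch, not part of the original train.py), PyTorch forward hooks can capture the intermediate outputs automatically instead of calling each submodule by hand:

```python
import torch

# Sketch: register forward hooks on the encoder and masknet so that their
# outputs are captured whenever the modules run.
captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        captured[name] = output
    return hook

handles = [
    separator.modules['encoder'].register_forward_hook(make_hook('encoder')),
    separator.modules['masknet'].register_forward_hook(make_hook('masknet')),
]

x = torch.randn(1, 1024)
mix_w = separator.modules['encoder'](x)
est_mask = separator.modules['masknet'](mix_w)
print({k: v.shape for k, v in captured.items()})

for h in handles:
    h.remove()  # detach the hooks once the outputs have been collected
```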
The complete `__main__` section of train.py is:
if __name__ == "__main__":

    # Load hyperparameters file with command-line overrides
    # hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])  # todo
    # rewrite as:
    hparams_file = '/home/zhaodeng/SoundPlus/speechbrain_4csdn/speechbrain/recipes/WSJ0Mix/separation/train/sepformer.yaml'
    hparams_file, run_opts, overrides = sb.parse_arguments([hparams_file])

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    # Logger info
    logger = logging.getLogger(__name__)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Check if wsj0_tr is set with dynamic mixing
    if hparams["dynamic_mixing"] and not os.path.exists(hparams["wsj0_tr"]):
        print(
            "Please, specify a valid wsj0_tr folder when using dynamic mixing"
        )
        sys.exit(1)

    # Data preparation  # todo: replace
    # from recipes.WSJ0Mix.prepare_data import prepare_wsjmix  # noqa
    #
    # run_on_main(
    #     prepare_wsjmix,
    #     kwargs={
    #         "datapath": hparams["data_folder"],
    #         "savepath": hparams["save_folder"],
    #         "n_spks": hparams["num_spks"],
    #         "skip_prep": hparams["skip_prep"],
    #     },
    # )
    from prepare_data import prepare_wsjmix  # noqa

    run_on_main(
        prepare_wsjmix,
        kwargs={
            "datapath": hparams["data_folder"],
            "savepath": hparams["save_folder"],
            "n_spks": hparams["num_spks"],
            "skip_prep": hparams["skip_prep"],
        },
    )

    # Create dataset objects
    if hparams["dynamic_mixing"]:
        if hparams["num_spks"] == 2:
            from dynamic_mixing import dynamic_mix_data_prep  # noqa

            train_data = dynamic_mix_data_prep(hparams)
        elif hparams["num_spks"] == 3:
            from dynamic_mixing import dynamic_mix_data_prep_3mix  # noqa

            train_data = dynamic_mix_data_prep_3mix(hparams)
        else:
            raise ValueError(
                "The specified number of speakers is not supported."
            )
        _, valid_data, test_data = dataio_prep(hparams)
    else:
        train_data, valid_data, test_data = dataio_prep(hparams)

    # Brain class initialization
    separator = Separation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    ## todo
    import torch
    print(separator.modules)
    x = torch.randn(1, 1024)
    output_encoder = separator.modules['encoder'](x)  # (1, 256, 127)
    output_masknet = separator.modules['masknet'](output_encoder)
    output_decoder_0 = separator.modules['decoder'](output_masknet[0])
    output_decoder_1 = separator.modules['decoder'](output_masknet[1])

    # re-initialize the parameters
    for module in separator.modules.values():
        separator.reset_layer_recursively(module)

    if not hparams["test_only"]:
        # Training
        separator.fit(
            separator.hparams.epoch_counter,
            train_data,
            valid_data,
            train_loader_kwargs=hparams["dataloader_opts"],
            valid_loader_kwargs=hparams["dataloader_opts"],
        )

    # Eval
    separator.evaluate(test_data, min_key="si-snr")
    separator.save_results(test_data)
Decomposing the PyTorch model (to map it layer by layer onto the TensorFlow model and obtain each layer's output)
Guided by the debugging above, the PyTorch SepFormer can in fact be decomposed layer by layer, or even step by step, emitting the output of every stage, as follows. A sanity check against the one-shot call is sketched after the function.
Note: a fixed speech file is used as input here, and the model parameters are loaded from a trained checkpoint.
def print_module_output():
    # imports needed for a self-contained run
    import torch
    import torchaudio
    from speechbrain.pretrained import SepformerSeparation as separator

    model = separator.from_hparams(
        source="/media/me/nvme2n1/SoundPlus/SpeechBrain/speechbrain-develop/recipes/WSJ0Mix/separation/zd/load/results-v3/sepformer/1234/save/CKPT+2021-06-10+16-52-44+00",
        savedir='./sepformer_train_3990_v3')
    output_write = open('layer_output_th.txt', 'w+')

    # test input: a fixed wav file, cropped and padded to 4 s at 8 kHz
    wav_file = '/media/me/nvme2n1/SoundPlus/SpeechBrain/sepformer_tf2/snr0_8k_16b.wav'
    batch, fs_file = torchaudio.load(wav_file)
    batch = batch[:, :4 * 8000]
    batch = torch.nn.functional.pad(batch, pad=[0, 32000 - len(batch[0])])  # optional

    # encoder
    mix_w_our = model.modules.encoder(batch)

    # masknet
    est_mask = model.modules.masknet.norm(mix_w_our)
    est_mask = model.modules.masknet.conv1d(est_mask)

    # segmentation
    segment, gap = model.modules.masknet._Segmentation(est_mask, model.modules.masknet.K)

    # dual_path
    est_mask = segment
    for dual_i in range(2):  # masknet_numlayers = 2
        est_mask0 = est_mask
        B, N, K, S = est_mask.shape
        est_mask = est_mask.permute(0, 3, 2, 1).contiguous().view(B * S, K, N)

        ## intra_mdl: pos_enc
        # positional embedding
        pos_enc = model.modules.masknet.dual_mdl[dual_i].intra_mdl.pos_enc(est_mask)
        est_mask = pos_enc + est_mask
        # intra_mdl.mdl: layers + norm
        src = est_mask
        output_our_list = []
        for layer_i in range(8):
            src1 = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].norm1(src)
            output, self_attns = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].self_att(
                src1,
                src1,
                src1,
                attn_mask=None,
                key_padding_mask=None,
            )
            src = src + model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].dropout1(output)
            src1 = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].norm2(src)
            output = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].pos_ffn(src1)
            src = src + model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.layers[layer_i].dropout2(output)
            output_our_list.append(src)
        est_mask = model.modules.masknet.dual_mdl[dual_i].intra_mdl.mdl.norm(src)
        ## intra_mdl end

        # intra_norm
        est_mask = est_mask.view(B, S, K, N)
        est_mask = est_mask.permute(0, 3, 2, 1).contiguous()
        est_mask = model.modules.masknet.dual_mdl[dual_i].intra_norm(est_mask)  # ok
        intra = est_mask + est_mask0

        ## inter_mdl
        # pos_enc
        inter = intra.permute(0, 2, 3, 1).contiguous().view(B * K, S, N)
        pos_enc = model.modules.masknet.dual_mdl[dual_i].inter_mdl.pos_enc(inter)
        est_mask = pos_enc + inter
        # inter_mdl: layers + norm
        src = est_mask
        for layer_i in range(8):
            src1 = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].norm1(src)
            output, self_attns = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].self_att(
                src1,
                src1,
                src1,
                attn_mask=None,
                key_padding_mask=None,
            )
            src = src + model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].dropout1(output)
            src1 = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].norm2(src)
            output = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].pos_ffn(src1)
            src = src + model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.layers[layer_i].dropout2(output)
        est_mask = model.modules.masknet.dual_mdl[dual_i].inter_mdl.mdl.norm(src)

        # inter_norm
        inter = est_mask.view(B, K, S, N)
        inter = inter.permute(0, 3, 1, 2).contiguous()
        inter = model.modules.masknet.dual_mdl[dual_i].inter_norm(inter)
        est_mask = inter + intra

    # dual_mdl: the loop above is equivalent to
    # x = segment
    # for i in range(model.modules.masknet.num_layers):
    #     x = model.modules.masknet.dual_mdl[i](x)

    # prelu
    est_mask = model.modules.masknet.prelu(est_mask)
    # conv2d
    est_mask = model.modules.masknet.conv2d(est_mask)
    # over_add
    B, _, K, S = est_mask.shape
    est_mask = est_mask.view(B * 2, -1, K, S)  # 2 = num_spks
    est_mask = model.modules.masknet._over_add(est_mask, gap)
    # output * output_gate
    output_o = model.modules.masknet.output(est_mask)
    output_gate = model.modules.masknet.output_gate(est_mask)
    est_mask = output_o * output_gate
    # conv1d
    est_mask = model.modules.masknet.end_conv1x1(est_mask)
    _, N, L = est_mask.shape
    est_mask = est_mask.view(B, model.modules.masknet.num_spks, N, L)
    est_mask = model.modules.masknet.activation(est_mask)
    est_mask_our = est_mask.transpose(0, 1)
    # or masknet in one command
    # est_mask_orig = model.modules.masknet(mix_w_our)

    # decoder
    mix_w_our = torch.stack([mix_w_our] * model.hparams.num_spks)
    sep_h_our = mix_w_our * est_mask_our
    # Decoding
    est_source_our = torch.cat(
        [
            model.modules.decoder(sep_h_our[i]).unsqueeze(-1)
            for i in range(model.hparams.num_spks)
        ],
        dim=-1,
    )

    # dump the model output to a text file
    layer_output = est_source_our.detach().numpy()
    for i in range(len(layer_output[0])):
        for j in range(len(layer_output[0][i])):
            output_write.write(str(layer_output[0][i][j]) + ' ')
        output_write.write('\n')
    output_write.write('\n')
    output_write.write('\n')

    ### IMPORTANT ###
    # normalize each separated source by its peak before saving
    est_source = est_source_our / est_source_our.max(dim=1, keepdim=True)[0]

    # save to wav file
    torchaudio.save("snr0_1.wav", est_source[:, :, 0].detach().cpu(), 8000)
    torchaudio.save("snr0_2.wav", est_source[:, :, 1].detach().cpu(), 8000)
    return mix_w_our, est_mask, est_source
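As a sanity check (a sketch; it assumes model, batch and est_mask_our from print_module_output above are in scope), the step-by-step decomposition should reproduce the one-shot masknet call, provided the modules are in eval mode so that the dropout layers are inactive:

```python
import torch

model.modules.eval()  # make the dropout layers no-ops so both paths match
with torch.no_grad():
    mix_w = model.modules.encoder(batch)
    est_mask_ref = model.modules.masknet(mix_w)  # one-shot: [spks, B, N, L]
    # est_mask_our is the step-by-step result computed above
    print(torch.allclose(est_mask_ref, est_mask_our, atol=1e-5))
```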
Saving the PyTorch model's trained parameters (as a yaml file that the TensorFlow model can load)
Here the (already trained) PyTorch model is loaded, each layer's parameters are stored in a dictionary under specific names, and the dictionary is saved to a yaml file.
The TensorFlow model can then load that yaml file and assign the parameters by name, as sketched below.
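A hypothetical sketch of the TensorFlow-side loader is shown first (it assumes the yaml keys were chosen to match tf_model.weights[i].name, e.g. 'encoder/conv1d/kernel:0', and that the numpy arrays were converted with .tolist() before dumping, since PyYAML cannot serialize ndarrays directly); the PyTorch-side extraction follows below:

```python
import numpy as np
import yaml

def load_torch_params(tf_model, yaml_path):
    """Assign yaml-stored PyTorch weights to a built tf.keras model by name."""
    with open(yaml_path) as fin:
        key_value = yaml.safe_load(fin)
    for var in tf_model.weights:
        if var.name in key_value:
            var.assign(np.asarray(key_value[var.name], dtype=np.float32))
        else:
            print('no pretrained value found for', var.name)
```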
def loadmodel_and_dumpyaml():
    from speechbrain.pretrained import SepformerSeparation as separator
    import numpy as np
    import yaml
    try:
        from yaml import CLoader as Loader, CDumper as Dumper
    except ImportError:
        from yaml import Loader, Dumper

    model = separator.from_hparams(
        source="/media/me/nvme2n1/SoundPlus/SpeechBrain/speechbrain-develop/recipes/WSJ0Mix/separation/zd/load/results-v3/sepformer/1234/save/CKPT+2021-06-10+16-52-44+00",
        savedir='./sepformer_train_3990_v3')
    # summary(model, (1, 32000, 1))

    yaml_key_value = {}

    # encoder
    scope = 'encoder'
    key = scope + '/conv1d/kernel:0'
    value = model.modules['encoder'].conv1d.weight.detach().numpy()
    # torch Conv1d weights are (out_channels, in_channels, kernel_size);
    # transpose to the tf.keras Conv1D layout (kernel_size, in, out) = (16, 1, 256)
    yaml_key_value[key] = value.transpose(2, 1, 0)

    # masknet
    scope = 'masknet'
    masknet = model.modules[