FastSpeech2代码理解之模型实现(一)

本文介绍了FastSpeech2,一个基于Transformer的语音合成模型。模型包含Encoder、Variance Adaptor和Decoder,其中Variance Adaptor用于结合时长、音高和能量信息。详细探讨了各部分的实现,包括/transformer/Layers.py中的Transformer块,/transformer/Models.py中的Encoder和Decoder,以及/model/modules.py中的Variance Predictor。
摘要由CSDN通过智能技术生成


参考

参考项目:FastSpeech2的github实现
FastSpeech2论文
FastSpeech2模型代码分析

FastSpeech2

FastSpeech2是一个基于Transformer的端到端语音合成模型,其结构如下:

在这里插入图片描述
Encoder将音素序列转换到隐藏序列,然后Variance Adaptor将不同的变量信息,如时长、音高、能量加入到到隐藏序列中,最终解码器将隐藏序列转换为梅尔谱序列。

1. FastSpeech2实现

FastSpeech2的实现位于/model/fastspeech2.py中:

class FastSpeech2(nn.Module):
    """ FastSpeech2 """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2, self).__init__()
        self.model_config = model_config

        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        self.postnet = PostNet()

        self.speaker_emb = None
        if model_config["multi_speaker"]:
            with open(
                os.path.join(
                    preprocess_config["path"]["preprocessed_path"], "speakers.json"
                ),
                "r",
            ) as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                model_config["transformer"]["encoder_hidden"],
            )

2. 模型结构

使用如下代码打印参考项目的FastSpeech2模型结构

from model import FastSpeech2
from utils.tools import get_configs_of

preprocess_config, model_config, train_config = get_configs_of("AISHELL3")

print(FastSpeech2(preprocess_config, model_config))

其中get_configs_of是原参考项目没有的,在/utils/tools.py中增加如下代码

import yaml

def get_configs_of(dataset):
    config_dir = os.path.join("./config", dataset)
    preprocess_config = yaml.load(open(
        os.path.join(config_dir, "preprocess.yaml"), "r"), Loader=yaml.FullLoader)
    model_config = yaml.load(open(
        os.path.join(config_dir, "model.yaml"), "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(
        os.path.join(config_dir, "train.yaml"), "r"), Loader=yaml.FullLoader)
    return preprocess_config, model_config, train_config

FastSpeech2结构打印如下,其中postnet模块在FastSpeech2的论文中是没有的,是此参考项目的作者增加的:

FastSpeech2(
  (encoder): Encoder(
    (src_word_emb): Embedding(361, 256, padding_idx=0)
    (layer_stack): ModuleList(
      (0): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (1): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (att
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

象牙塔♛

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值