fastspeech复现github项目--模型训练

在完成fastspeech论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/xcmyz/FastSpeech。该仓库使用的数据集为LJSpeech,数据处理部分的代码见笔记“fastspeech复现github项目–数据准备”、模型构建的代码见笔记“fastspeech复现github项目–模型构建”。本笔记对FastSpeech模型训练相关代码进行详细注释,主要代码是仓库中的dataset.py、loss.py、optimizer.py、train.py、eval.py。

dataset.py

该文件是主要用于数据加载和数据转换,将文本、持续时间和mel谱图序列加载封装至定义的BufferDataset对象中,然后定义回调函数collate_fn_tensor将对数据进行pad等操作,转换为模型训练所需的格式

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math
import time
import os

import hparams
import audio

from utils import process_text, pad_1D, pad_2D
from utils import pad_1D_tensor, pad_2D_tensor
from text import text_to_sequence
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_data_to_buffer():
    buffer = list()
    # 将全部的音频文本读取到一个列表对象中,text是一个列表,每一个元素是一个字符串,即一个音频对应的文本
    text = process_text(os.path.join("data", "train.txt"))

    start = time.perf_counter()
    for i in tqdm(range(len(text))):

        mel_gt_name = os.path.join(
            hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (i+1))
        mel_gt_target = np.load(mel_gt_name)  # 加载文本对应的音频文件的mel谱图
        duration = np.load(os.path.join(
            hparams.alignment_path, str(i)+".npy"))  # 加载对应的持续时间
        character = text[i][0:len(text[i])-1]  # 删除最后的换行符
        character = np.array(
            text_to_sequence(character, hparams.text_cleaners))  # 将英文文本转换为数值序列,相当于分词
        print(sum(duration))

        # character和duration的长度一致,即duration中的i的值,表示character中i位置的数值出现的次数
        character = torch.from_numpy(character)
        # dutation中所有数值之和与mel的长度相等,即character经过duration调整后,文本长度将于mel谱图长度对齐
        duration = torch.from_numpy(duration)
        mel_gt_target = torch.from_numpy(mel_gt_target)
        # 将一个音频文件的文本、持续时间和mel谱图数据组合成一个元组对象存在在列表中
        buffer.append({"text": character, "duration": duration,
                       "mel_target": mel_gt_target})

    end = time.perf_counter()
    print("cost {:.2f}s to load all data into buffer.".format(end-start))

    return buffer


class BufferDataset(Dataset):
    def __init__(self, buffer):
        self.buffer = buffer  # 加载所有数据
        self.length_dataset = len(self.buffer)  # 数据集总数量

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        return self.buffer[idx]


def reprocess_tensor(batch, cut_list):
    '''
    以传入的batch数据和对应的序列索引,给文本序列、mel谱图序列建立位置信息,同时将其封装在一起输出
    @param batch:一个大batch的数据
    @param cut_list:一个real batch大小的索引列表,其对应的文本长度从达到小降序排列
    @return:
    '''
    texts = [batch[ind]["text"] for ind in cut_list]  # batch中的文本
    mel_targets = [batch[ind]["mel_target"] for ind in cut_list]  # batch中的gt梅尔谱图
    durations = [batch[ind]["duration"] for ind in cut_list]  # batch中的duration时间

    length_text = np.array([])  # 存储所有文本序列的长度大小
    for text in texts:
        length_text = np.append(length_text, text.size(0))

    src_pos = list()
    max_len = int(max(length_text))  # 最大文本长度
    for length_src_row in length_text:
        # 给每个文本生成src_pos,从1到文本的长度,如果长度小于max_len,对应部分用0填充
        src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
                              (0, max_len-int(length_src_row)), 'constant'))
    src_pos = torch.from_numpy(np.array(src_pos))

    length_mel = np.array(list())  # 存储所有mel谱图序列的长度大小
    for mel in mel_targets:
        length_mel = np.append(length_mel, mel.size(0))

    mel_pos = list()
    max_mel_len = int(max(length_mel))  # 最大mel谱图序列长度
    for length_mel_row in length_mel:
        # 给每个mel谱图序列生成mel_pos,从1到序列的长度,如果长度小于max_mel_len,对应部分用0填充
        mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
                              (0, max_mel_len-int(length_mel_row)), 'constant'))
    mel_pos = torch.from_numpy(np.array(mel_pos))

    texts = pad_1D_tensor(texts)  # 将所有的文本都pad到文本的最大长度
    durations = pad_1D_tensor(durations)  # 将所有的duration持续时间pad到最大长度
    mel_targets = pad_2D_tensor(mel_targets)  # 将所有mel谱图序列pad到最大长度

    out = {"text": texts,
           "mel_target": mel_targets,
           "duration": durations,
           "mel_pos": mel_pos,
           "src_pos": src_pos,
           "mel_max_len": max_mel_len}

    return out


# 构建Loader时数据转换的回调函数
def collate_fn_tensor(batch):
    len_arr = np.array([d["text"].size(0) for d in batch])  # 一个batch中文本序列的长度列表
    index_arr = np.argsort(-len_arr)  # 对len_arr进行降序排序后,从大到小返回值在原列表中的索引
    batchsize = len(batch)
    real_batchsize = batchsize // hparams.batch_expand_size

    cut_list = list()
    for i in range(hparams.batch_expand_size):
        cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])  # 将index_arr分成hparams.batch_expand_size段

    output = list()
    for i in range(hparams.batch_expand_size):
        output.append(reprocess_tensor(batch, cut_list[i]))

    return output  # output中一个元素就是一个real batch的数据


if __name__ == "__main__":
    # TEST
    # get_data_to_buffer()

    a = np.array([3, 5, 1, 8, 4])
    print(np.argsort(-a))

其中对数据进行pad是使用了utils.py文件中的两个函数,如下所示

# 对传入的一维张量进行pad
def pad_1D_tensor(inputs, PAD=0):

    def pad_data(x, length, PAD):
        x_padded = F.pad(x, (0, length - x.shape[0]))
        return x_padded

    max_len = max((len(x) for x in inputs))
    padded = torch.stack([pad_data(x, max_len, PAD) for x in inputs])

    return padded

# 对二维张量进行pad
def pad_2D_tensor(inputs, maxlen=None):

    def pad(x, max_len):  #
        if x.size(0) > max_len:
            raise ValueError("not max_len")

        s = x.size(1)
        x_padded = F.pad(x, (0, 0, 0, max_len-x.size(0)))  # 在垂直方向,用0填充max_len-x.size(0)行
        return x_padded[:, :s]

    # 如果传入了maxlen就直接pad,没有就以mel谱图最长的值为maxlen进行pad
    if maxlen:
        output = torch.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(x.size(0) for x in inputs)
        output = torch.stack([pad(x, max_len) for x in inputs])

    return output

loss.py

FastSpeech在训练时会对duration predictor同时训练,结合之前自回归模型均会对最后mel经过postnet处理的前后计算损失,故训练过程中会计算三个损失。loss.py文件中就定义了损失类

import torch
import torch.nn as nn


# 自定义的损失,由两种损失组成,一种分为三块,mel谱图损失分由两圈,即和之前的模型一样,postnet前后都计算损失
class DNNLoss(nn.Module):
    def __init__(self):
        super(DNNLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, mel_postnet, duration_predicted, mel_target, duration_predictor_target):
        mel_target.requires_grad = False  # 目标信息不需要计算梯度,此处计算mel损失
        mel_loss = self.mse_loss(mel, mel_target)  # postnet之前的损失
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)  # postnet之后的损失

        duration_predictor_target.requires_grad = False  # 目标信息不需要计算梯度,此处计算音素持续时间损失
        # 训练duration predictor的损失
        duration_predictor_loss = self.l1_loss(duration_predicted,
                                               duration_predictor_target.float())

        return mel_loss, mel_postnet_loss, duration_predictor_loss

optimizer.py

该文件中封装了一个学习率优化类,其可以实现学习率动态变化和冻结两种更新方式

import numpy as np


# 为学习率方案封装的一个包装器类
class ScheduledOptim():
    ''' A simple wrapper class for learning rate scheduling '''

    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps):
        self._optimizer = optimizer  # 优化器
        self.n_warmup_steps = n_warmup_steps  # warmup的步数
        self.n_current_steps = current_steps  # 训练时的当前步数
        self.init_lr = np.power(d_model, -0.5)  # 学习率

    # 将学习率冻结之后,再进行参数更新
    def step_and_update_lr_frozen(self, learning_rate_frozen):
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = learning_rate_frozen
        self._optimizer.step()

    # 使用设置的学习率方案进行参数更新
    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    # 返回当前的学习率
    def get_learning_rate(self):
        learning_rate = 0.0
        for param_group in self._optimizer.param_groups:
            learning_rate = param_group['lr']

        return learning_rate

    # 清除梯度
    def zero_grad(self):
        # print(self.init_lr)
        self._optimizer.zero_grad()

    # 学习率变化规则
    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    # 该学习方案中每步的学习率
    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''
        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()  # 计算当前step的学习率
        # 给所有参数设置学习率
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

train.py

该文件是FastSpeech模型训练过程实现代码,整体流程与普通模型训练一样,需要注意的一点就是数据划分过程中,是分成了一个大batch,其中包含数个real batch,故训练过程是三个for循环的嵌套,与正常的两个for循环嵌套不同,该现象也可以在dataset.py文件中观察到

import torch
import torch.nn as nn
import torch.nn.functional as F

from multiprocessing import cpu_count
import numpy as np
import argparse
import os
import time
import math

from model import FastSpeech
from loss import DNNLoss
from dataset import BufferDataset, DataLoader
from dataset import get_data_to_buffer, collate_fn_tensor
from optimizer import ScheduledOptim
import hparams as hp
import utils


def main(args):
    # Get device
    device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')

    # Define model
    print("Use FastSpeech")
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)
    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer()

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.decoder_dim,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size * hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size  # 整个训练过程的总步数

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):  # 此处是一个大batch
            # real batch start here
            for j, db in enumerate(batchs):  # db才是一个real batch
                start_time = time.perf_counter()

                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(character,
                                                                                  src_pos,
                                                                                  mel_pos=mel_pos,
                                                                                  mel_max_length=max_mel_len,
                                                                                  length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(mel_output,
                                                                            mel_postnet_output,
                                                                            duration_predictor_output,
                                                                            mel_target,
                                                                            duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()
                # 记录损失等日志信息
                with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l)+"\n")

                with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l)+"\n")

                with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l)+"\n")

                with open(os.path.join("logger", "duration_loss.txt"), "a") as f_d_loss:
                    f_d_loss.write(str(d_l)+"\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion,梯度剪裁
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)

                # Update weights,更新权重,可是否设置学习率冻结
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch+1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now-Start), (total_step-current_step)*np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")
                # 模型保存
                if current_step % hp.save_step == 0:
                    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(
                    )}, os.path.join(hp.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(
                        Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_step', type=int, default=0)
    parser.add_argument('--frozen_learning_rate', type=bool, default=False)
    parser.add_argument("--learning_rate_frozen", type=float, default=1e-3)
    args = parser.parse_args()
    main(args)

eval.py

评估文件中就是使用训练好的FastSpeech模型基于文本预测mel谱图,然后使用grif-lim算法和Waveglow模型生成音频

import torch
import torch.nn as nn
import argparse
import numpy as np
import random
import time
import shutil
import os

import hparams as hp
import audio
import utils
import dataset
import text
import model as M
import waveglow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_DNN(num):  # 加载模型
    checkpoint_path = "checkpoint_" + str(num) + ".pth.tar"
    model = nn.DataParallel(M.FastSpeech()).to(device)
    model.load_state_dict(torch.load(os.path.join(hp.checkpoint_path,
                                                  checkpoint_path))['model'])
    model.eval()
    return model


def synthesis(model, text, alpha=1.0):
    text = np.array(phn)
    text = np.stack([text])
    src_pos = np.array([i+1 for i in range(text.shape[1])])
    src_pos = np.stack([src_pos])
    sequence = torch.from_numpy(text).cuda().long()
    src_pos = torch.from_numpy(src_pos).cuda().long()

    with torch.no_grad():
        _, mel = model.module.forward(sequence, src_pos, alpha=alpha)  # mel谱图预测
    return mel[0].cpu().transpose(0, 1), mel.contiguous().transpose(1, 2)


def get_data():
    test1 = "I am very happy to see you again!"
    test2 = "Durian model is a very good speech synthesis!"
    test3 = "When I was twenty, I fell in love with a girl."
    test4 = "I remove attention module in decoder and use average pooling to implement predicting r frames at once"
    test5 = "You can not improve your past, but you can improve your future. Once time is wasted, life is wasted."
    test6 = "Death comes to all, but great achievements raise a monument which shall endure until the sun grows old."
    data_list = list()
    data_list.append(text.text_to_sequence(test1, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test2, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test3, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test4, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test5, hp.text_cleaners))
    data_list.append(text.text_to_sequence(test6, hp.text_cleaners))
    return data_list


if __name__ == "__main__":
    # Test
    WaveGlow = utils.get_WaveGlow()  # 加载Waveglow作为声码器
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=0)
    parser.add_argument("--alpha", type=float, default=1.0)
    args = parser.parse_args()

    print("use griffin-lim and waveglow")
    model = get_DNN(args.step)  # 加载FastSpeech模型
    data_list = get_data()  # 加载文本数据
    for i, phn in enumerate(data_list):
        mel, mel_cuda = synthesis(model, phn, args.alpha)
        if not os.path.exists("results"):
            os.mkdir("results")
        # 使用grif-lim重建音频文件
        audio.tools.inv_mel_spec(
            mel, "results/"+str(args.step)+"_"+str(i)+".wav")
        # 使用Waveglow重建音频文件
        waveglow.inference.inference(
            mel_cuda, WaveGlow,
            "results/"+str(args.step)+"_"+str(i)+"_waveglow.wav")
        print("Done", i + 1)

    s_t = time.perf_counter()
    for i in range(100):
        for _, phn in enumerate(data_list):
            _, _, = synthesis(model, phn, args.alpha)
        print(i)
    e_t = time.perf_counter()
    print((e_t - s_t) / 100.)

本笔记主要记录所选择的fastspeech复现仓库中模型训练相关的代码,结合之前FastSppech论文阅读笔记中的模型部分进行理解;仓库中使用Waveglow作为声码器生成音频,相关的代码就不再解析注释了;至此,FastSpeech复现仓库中主要的代码注释解析就基本完成了。

FastSpeech复现仓库的代码共总结成三篇笔记,除本篇笔记外,分别是fastspeech复现github项目–数据准备fastspeech复现github项目–模型构建,均是对代码进行详细的注释,读者若发现问题或错误,请评论指出,互相学习。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值