VITS Official GitHub Project -- Model Training

After working through the VITS paper, this note studies the official repository on GitHub to clarify some details of the implementation. The code is written in PyTorch and is available at https://github.com/jaywalnut310/vits. Both the paper and the code train on the single-speaker dataset LJSpeech and the multi-speaker dataset VCTK; this note annotates and analyzes the training code for the multi-speaker setting, which mainly lives in the repository's train_ms.py file.

train_ms.py

VITS is trained with mixed precision and in an adversarial setup. The discriminator is a multi-period discriminator composed of several sub-discriminators, and the generator's loss additionally includes a feature-map (feature-matching) loss. Training does not consume complete audio files; instead, a slice of each utterance is sampled, so the corresponding slice must also be extracted from the ground truth when computing the losses (a sketch of the slicing helper is given after the listing). The annotated training code is as follows:

import os
import json
import argparse
import itertools
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

import commons
import utils
from data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler
)
from models import (
    SynthesizerTrn,
    MultiPeriodDiscriminator,
)
from losses import (
    generator_loss,
    discriminator_loss,
    feature_loss,
    kl_loss
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from text.symbols import symbols

torch.backends.cudnn.benchmark = True
global_step = 0


def main():
    """Assume Single Node Multi GPUs Training Only;只考虑单机多卡训练"""
    assert torch.cuda.is_available(), "CPU training is not allowed."

    n_gpus = torch.cuda.device_count()
    os.environ['MASTER_ADDR'] = 'localhost'
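    # note: 80000 lies above the valid TCP port range (0-65535); depending on your PyTorch version you may need to change this to a free port below 65536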
    os.environ['MASTER_PORT'] = '80000'

    hps = utils.get_hparams()  # load the hyperparameters
    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))


# run() contains the actual training code
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)  # load the training dataset
    # distributed bucket-based sampler
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],  # bucket boundaries (by spectrogram length)
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True)
    collate_fn = TextAudioSpeakerCollate()
    # build the training DataLoader
    train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
                              collate_fn=collate_fn, batch_sampler=train_sampler)
    if rank == 0:  # validation runs only on rank 0, so the validation set is loaded only on the main process
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,
                                 batch_size=hps.train.batch_size, pin_memory=True,
                                 drop_last=False, collate_fn=collate_fn)
    # generator: the full text-to-waveform model
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).cuda(rank)
    # multi-period discriminator
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    # optimizer for the generator
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    # optimizer for the discriminator
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    # wrap the generator and discriminator with DDP for multi-GPU distributed training
    net_g = DDP(net_g, device_ids=[rank])
    net_d = DDP(net_d, device_ids=[rank])

    try:  # try to resume from previously saved checkpoints, if any exist
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
                                                   optim_g)
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
                                                   optim_d)
        global_step = (epoch_str - 1) * len(train_loader)
    except:
        epoch_str = 1
        global_step = 0

    # learning-rate schedulers for the generator and the discriminator
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
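    # note: with a fresh run epoch_str = 1, so last_epoch = -1 (PyTorch's "no steps taken yet"); when resuming, the schedulers continue from the restored epoch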

    scaler = GradScaler(enabled=hps.train.fp16_run)  # gradient scaler for mixed-precision training

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:  # on rank 0, also pass the eval loader, logger and writers in addition to the usual training arguments
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
                               [train_loader, eval_loader], logger, [writer, writer_eval])
        else:
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
                               [train_loader, None], None, None)
        # step the learning-rate schedulers
        scheduler_g.step()
        scheduler_d.step()


# training and evaluation function
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    net_g, net_d = nets  # generator and discriminator
    optim_g, optim_d = optims
    scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)  # seed the bucket sampler with the epoch number so shuffling differs per epoch while remaining reproducible
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(train_loader):
        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
        speakers = speakers.cuda(rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):  # model forward passes run in mixed precision
            # train on a randomly sampled slice of each utterance rather than the full sequence, which reduces memory;
            # ids_slice holds the start frame of each sampled spectrogram window
            # y_hat: predicted waveform; l_length: duration-predictor loss; attn: alignment / duration information
            y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
            (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)

            # convert the linear spectrogram to a mel spectrogram for the reconstruction loss L_recon
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            # slice the window indicated by ids_slice out of the target mel spectrogram
            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
            # compute the mel spectrogram of the generated waveform y_hat
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            # slice the matching window (in samples, hence ids_slice * hop_length) out of the full ground-truth waveform; the discriminator needs real waveform data
            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice

            # Discriminator: y_d_hat_r / y_d_hat_g hold every sub-discriminator's outputs for the real waveform y and the generated waveform y_hat
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
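            # y_hat is detached here so that the discriminator update does not backpropagate into the generator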

            with autocast(enabled=False):  # loss computation stays in full precision
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)  # discriminator loss
                loss_disc_all = loss_disc
        # discriminator update
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)  # unscale gradients before clipping
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)  # gradient clipping
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            # feed the generated and real waveforms through the discriminator again; their intermediate feature maps should match as closely as possible (L_{fm} in the paper), computed from fmap_r and fmap_g
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_dur = torch.sum(l_length.float())  # duration-predictor loss, simply summed
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel  # reconstruction loss; the coefficient c_mel is 45 in the paper
                # KL divergence between the prior learned from text and the posterior learned from the linear spectrogram; the coefficient c_kl is 1
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)  # feature-matching loss
                loss_gen, losses_gen = generator_loss(y_d_hat_g)  # adversarial loss for the generator
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
        # generator update
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()
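        # note: one GradScaler is shared by both optimizers; unscale_ and step are called per optimizer, while update() runs once per iteration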
        # rank 0 handles loss logging, evaluation, and checkpoint saving
        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info('Train Epoch: {} [{:.0f}%]'.format(
                    epoch,
                    100. * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
                               "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}  # record losses and gradient norms
                scalar_dict.update(
                    {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl})

                scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
                # log the mel spectrograms and the alignment as images
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
                    "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                    "all/attn": utils.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy())
                }
                # write the above to TensorBoard via the writer defined earlier
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict)

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)  # run evaluation
                # save generator and discriminator checkpoints
                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
                                      os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
                                      os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
        global_step += 1

    if rank == 0:
        logger.info('====> Epoch: {}'.format(epoch))


# evaluation
def evaluate(hps, generator, eval_loader, writer_eval):
    generator.eval()  # evaluation mode
    with torch.no_grad():
        for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader):
            x, x_lengths = x.cuda(0), x_lengths.cuda(0)
            spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
            y, y_lengths = y.cuda(0), y_lengths.cuda(0)
            speakers = speakers.cuda(0)

            # remove else
            x = x[:1]
            x_lengths = x_lengths[:1]
            spec = spec[:1]
            spec_lengths = spec_lengths[:1]
            y = y[:1]
            y_lengths = y_lengths[:1]
            speakers = speakers[:1]
            break
        y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)  # synthesize audio from text
        y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

        # mel spectrogram of the ground-truth audio
        mel = spec_to_mel_torch(
            spec,
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.mel_fmin,
            hps.data.mel_fmax)
        # mel spectrogram extracted from the generated audio
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1).float(),
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            hps.data.mel_fmin,
            hps.data.mel_fmax
        )
    image_dict = {
        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
    }
    audio_dict = {
        "gen/audio": y_hat[0, :, :y_hat_lengths[0]]
    }
    if global_step == 0:
        image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
        audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]})

    # log the results
    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate
    )
    generator.train()


if __name__ == "__main__":
    main()
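The listing above repeatedly cuts matching windows out of spectrograms and waveforms with commons.slice_segments; the random start indices in ids_slice come from commons.rand_slice_segments inside the model's forward pass. Both helpers live in the repository's commons.py. As a rough illustration of the idea (a sketch, not necessarily the repo's exact implementation):

import torch


def slice_segments_sketch(x, ids_str, segment_size):
    """Cut a fixed-length window out of each item in the batch.

    x: [batch, channels, time]; ids_str: [batch] start index per item.
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        start = int(ids_str[i])
        out[i] = x[i, :, start:start + segment_size]
    return out

In train_ms.py the same ids_slice is applied twice: on the mel spectrogram the window length is segment_size // hop_length frames, while on the waveform the start index is multiplied by hop_length and the window is segment_size samples long, so both slices cover exactly the same stretch of audio.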

losses.py

As the paper explains, training involves several losses. For the adversarial part, the discriminator uses the standard discriminator-loss structure, but it is a multi-period discriminator composed of several sub-discriminators. The generator's loss consists of the mel reconstruction loss, the KL divergence, the duration-predictor loss, the adversarial generator loss, and the feature-map loss. The duration-predictor loss is computed directly inside the model's forward function and the mel reconstruction loss is a plain L1 loss; the remaining four losses (discriminator loss, generator adversarial loss, feature-matching loss, and KL loss) are defined in losses.py, listed below with comments.
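For orientation before the code: the adversarial losses below follow the least-squares GAN objective described in the paper. With y the real waveform, y_hat the generated waveform, and D_k the k-th sub-discriminator:

$$\mathcal{L}_{adv}(D) = \sum_k \mathbb{E}\left[(D_k(y)-1)^2 + D_k(\hat{y})^2\right], \qquad \mathcal{L}_{adv}(G) = \sum_k \mathbb{E}\left[(D_k(\hat{y})-1)^2\right]$$

discriminator_loss and generator_loss below implement exactly these sums and additionally return the per-sub-discriminator terms for logging.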

import torch
from torch.nn import functional as F

import commons


# feature-matching loss: distance between the discriminator's intermediate features for real and generated waveforms
def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):  # iterate over each sub-discriminator's feature maps for the real and generated waveforms
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))  # L1 distance

    return loss * 2


# discriminator loss
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):  # iterate over the outputs of the sub-discriminators
        dr = dr.float()  # one sub-discriminator's output for the real waveform
        dg = dg.float()  # one sub-discriminator's output for the generated waveform
        r_loss = torch.mean((1 - dr) ** 2)  # the output for real waveforms should be close to 1
        g_loss = torch.mean(dg ** 2)  # the output for generated waveforms should be close to 0
        loss += (r_loss + g_loss)  # accumulate this sub-discriminator's loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


# adversarial loss for the generator: squared distance between the discriminator's outputs on generated waveforms and 1
def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


# KL divergence between the prior and the posterior distribution
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l
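One way to read the formula above (a sketch in the function's own variable names, with sigma = exp(logs)): kl_loss estimates E_q[log q(z) - log p(z)] by combining the analytic entropy of the posterior q with a one-sample estimate of the cross-entropy under the prior p, evaluated at the sample z_p:

$$\mathbb{E}_q[\log q(z)] = -\log\sigma_q - \tfrac{1}{2} - \tfrac{1}{2}\log 2\pi, \qquad -\log p(z_p) = \log\sigma_p + \frac{(z_p - m_p)^2}{2\sigma_p^2} + \tfrac{1}{2}\log 2\pi$$

The log(2*pi) terms cancel, so each element contributes logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)**2 * exp(-2 * logs_p), exactly the expression in the code; the masked sum is then divided by the number of valid positions in z_mask to give a per-element mean.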

This note covers the model-training code in the official VITS repository; some of the helper functions it relies on may be covered in a follow-up if needed. The focus is on annotating the code in detail. If you spot any problems or mistakes, please point them out in the comments so we can learn from each other.
