fastspeech2复现github项目--数据准备

在完成FastSpeech2论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/ming024/FastSpeech2。该仓库是基于https://github.com/xcmyz/FastSpeech中的FastSpeech复现代码完成的,很多代码基本一致。作者前期已对该FastSpeech复现仓库进行注释分析,感兴趣的读者可见此专栏

此次的FastSpeech2复现仓库中除了使用LJSpeech数据集训练单speaker模型外,还使用了AISHELL和LibriTTS数据集分别训练中文多speaker和英文多speaker模型。本次解析主要还是针对LJSpeech数据集的数据准备,本仓库中audio路径下的py文件与FastSPeech复现仓库中audio的py文件一致,主要用于从音频文件中提取mel谱图文件,可跳转至此笔记“fastspeech复现github项目–数据准备”了解细节;初次之外,数据准备中主要涉及得到的文件是prepare_align.py以及preprocessor路径下的ljspeech.py和preprocessor.py

在进行数据处理前,先将LJSpeech数据集下载至本地,在FastSpeech2论文中使用强制对齐工具MFA从文本和音频中提取对齐信息,代码解析时使用的是作者提供的已经提取好的对齐信息文件,感兴趣的读者也可以自行下载、安装MFA提取对齐信息。根据仓库作者提供的链接下载的每一个*.TextGrid文件与一个音频对应,其中记录了word_level和phone_level两个级别的文本、对应持续时间(单位为秒)等信息,具体格式如下图所示,主要区别就是phone_level比word_level经精细,颗粒度更小。
在这里插入图片描述

TextGrid中word_level文件类型

在这里插入图片描述

TextGrid中phone_level文件类型

prepare_align.py

该文件就是相当于一个接口,针对不同的数据集调用对应的文件函数进行数据准备,主要就是调用数据集对应的prepare_align函数处理数据

import argparse

import yaml

from preprocessor import ljspeech, aishell3, libritts


def main(config):
    if "LJSpeech" in config["dataset"]:
        ljspeech.prepare_align(config)
    if "AISHELL3" in config["dataset"]:
        aishell3.prepare_align(config)
    if "LibriTTS" in config["dataset"]:
        libritts.prepare_align(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to preprocess.yaml")  # 加载对应的yaml文件,便于后面添加相应参数
    args = parser.parse_args()

    config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
    main(config)

ljspeech.py文件

虽然该文件只定义了prepare_align函数,但是该函数也只是简单的将LJSpeech数据集中的音频数据和文本数据进行了处理并保存,并没有提取对齐信息。

import os

import librosa
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm

from text import _clean_text


def prepare_align(config):
    in_dir = config["path"]["corpus_path"]  # LJSpeech数据集存储路径
    out_dir = config["path"]["raw_path"]  # 数据转化后的存储路径
    sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
    max_wav_value = config["preprocessing"]["audio"]["max_wav_value"]
    cleaners = config["preprocessing"]["text"]["text_cleaners"]
    speaker = "LJSpeech"
    with open(os.path.join(in_dir, "metadata.csv"), encoding="utf-8") as f:
        for line in tqdm(f):
            parts = line.strip().split("|")  # 分割音频文件路劲名和内容文本
            base_name = parts[0]  # 文件路径名
            text = parts[2]  # 音频对应的内容文本
            text = _clean_text(text, cleaners)  # 使用text库提供的接口结合cleaner对文本进行调整

            wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name))  # 获取完整音频文件路径
            if os.path.exists(wav_path):
                os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
                wav, _ = librosa.load(wav_path, sampling_rate)
                wav = wav / max(abs(wav)) * max_wav_value
                # 将处理之后的wav保存
                wavfile.write(
                    os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
                    sampling_rate,
                    wav.astype(np.int16),
                )
                # 将调整后的文本序列保存
                with open(
                    os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
                    "w",
                ) as f1:
                    f1.write(text)

preprocessor.py

该文件中才是从下载的TextGrid文件中提取每条音频对应的duration、pitch和energy信息;其中的config是通过config/LJSpeech/preprocess.yam文件加载而来。

import os
import random
import json

import tgt
import librosa
import numpy as np
import pyworld as pw
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

import audio as Audio


# 定义处理所有数据的处理类
class Preprocessor:
    def __init__(self, config):
        self.config = config
        self.in_dir = config["path"]["raw_path"]  # 存放原始LJSpeech数据的路径
        self.out_dir = config["path"]["preprocessed_path"]  # 数据存储后的路径
        self.val_size = config["preprocessing"]["val_size"]
        self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"]
        self.hop_length = config["preprocessing"]["stft"]["hop_length"]

        assert config["preprocessing"]["pitch"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        assert config["preprocessing"]["energy"]["feature"] in [
            "phoneme_level",
            "frame_level",
        ]
        # 是否进行pitch_phoneme_averaging
        self.pitch_phoneme_averaging = (
                config["preprocessing"]["pitch"]["feature"] == "phoneme_level")
        # 是否进行energy_phoneme_averaging
        self.energy_phoneme_averaging = (
                config["preprocessing"]["energy"]["feature"] == "phoneme_level")
		# 是否进行正则化
        self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"]
        self.energy_normalization = config["preprocessing"]["energy"]["normalization"]
        # 初始化STFT模块
        self.STFT = Audio.stft.TacotronSTFT(
            config["preprocessing"]["stft"]["filter_length"],
            config["preprocessing"]["stft"]["hop_length"],
            config["preprocessing"]["stft"]["win_length"],
            config["preprocessing"]["mel"]["n_mel_channels"],
            config["preprocessing"]["audio"]["sampling_rate"],
            config["preprocessing"]["mel"]["mel_fmin"],
            config["preprocessing"]["mel"]["mel_fmax"],
        )
	# 提出所需数据
    def build_from_path(self):
        os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True)
        os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True)

        print("Processing Data ...")
        out = list()
        n_frames = 0
        pitch_scaler = StandardScaler()
        energy_scaler = StandardScaler()

        # Compute pitch, energy, duration, and mel-spectrogram
        speakers = {}
        # 下面的一个speaker就是一个文件夹,同一speaker的音频放在同一路径下
        for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))):
            speakers[speaker] = i
            for wav_name in os.listdir(os.path.join(self.in_dir, speaker)):
                if ".wav" not in wav_name:
                    continue

                basename = wav_name.split(".")[0]
                # 基于音频文件的basename构建对应的对齐文件路径名
                tg_path = os.path.join(
                    self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
                )
                if os.path.exists(tg_path):
                    ret = self.process_utterance(speaker, basename)  # 提取单个音频的mel、pitch、energy数据
                    if ret is None:
                        continue
                    else:
                        info, pitch, energy, n = ret  # n是mel谱图序列的总帧数
                    out.append(info)  # 记录info中文本相关的数据,是一个用“|”分割的字符串

                if len(pitch) > 0:
                    pitch_scaler.partial_fit(pitch.reshape((-1, 1)))
                if len(energy) > 0:
                    energy_scaler.partial_fit(energy.reshape((-1, 1)))

                n_frames += n

        print("Computing statistic quantities ...")
        # Perform normalization if necessary
        if self.pitch_normalization:
            pitch_mean = pitch_scaler.mean_[0]
            pitch_std = pitch_scaler.scale_[0]
        else:
            # A numerical trick to avoid normalization...
            pitch_mean = 0
            pitch_std = 1
        if self.energy_normalization:
            energy_mean = energy_scaler.mean_[0]
            energy_std = energy_scaler.scale_[0]
        else:
            energy_mean = 0
            energy_std = 1

        pitch_min, pitch_max = self.normalize(
            os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std
        )
        energy_min, energy_max = self.normalize(
            os.path.join(self.out_dir, "energy"), energy_mean, energy_std
        )

        # Save files
        with open(os.path.join(self.out_dir, "speakers.json"), "w") as f:
            f.write(json.dumps(speakers))

        with open(os.path.join(self.out_dir, "stats.json"), "w") as f:
            stats = {
                "pitch": [
                    float(pitch_min),
                    float(pitch_max),
                    float(pitch_mean),
                    float(pitch_std),
                ],
                "energy": [
                    float(energy_min),
                    float(energy_max),
                    float(energy_mean),
                    float(energy_std),
                ],
            }
            f.write(json.dumps(stats))

        print(
            "Total time: {} hours".format(
                n_frames * self.hop_length / self.sampling_rate / 3600
            )
        )

        random.shuffle(out)
        out = [r for r in out if r is not None]

        # Write metadata,划分训练集文本数据和验证集文本数据
        with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f:
            for m in out[self.val_size:]:
                f.write(m + "\n")
        with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f:
            for m in out[: self.val_size]:
                f.write(m + "\n")

        return out
	# 基于文件路径提取音频文件的mel、pitch、energy、duration数据
    def process_utterance(self, speaker, basename):
        wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename))  # 音频文件路径
        text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename))  # 文本文件路径
        tg_path = os.path.join(
            self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename)
        )

        # Get alignments
        textgrid = tgt.io.read_textgrid(tg_path)  # 读取textgrid对象
        # 数据提取。phone中是textgrid对象中文本转为音素的列表,duration中为音素列表中每个元素对应的mel帧数,即每个音素的持续时间,start为音频开始时间,end为结束时间
        phone, duration, start, end = self.get_alignment(
            textgrid.get_tier_by_name("phones"))
        text = "{" + " ".join(phone) + "}"  # 文本信息拼接成字符串方便存储
        if start >= end:
            return None

        # Read and trim wav files
        wav, _ = librosa.load(wav_path)  # 加载音频
        wav = wav[
              int(self.sampling_rate * start): int(self.sampling_rate * end)
              ].astype(np.float32)

        # Read raw text
        with open(text_path, "r") as f:
            raw_text = f.readline().strip("\n")  # 音频对应文本

        # Compute fundamental frequency
        pitch, t = pw.dio(
            wav.astype(np.float64),
            self.sampling_rate,
            frame_period=self.hop_length / self.sampling_rate * 1000,
        )
        pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate)

        pitch = pitch[: sum(duration)]  # 与总的mel谱图帧数对齐
        if np.sum(pitch != 0) <= 1:
            return None

        # Compute mel-scale spectrogram and energy
        mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT)  # 计算mel谱图
        mel_spectrogram = mel_spectrogram[:, : sum(duration)]
        energy = energy[: sum(duration)]

        if self.pitch_phoneme_averaging:
            # perform linear interpolation,线性插值,就是将pitch序列中为0的值赋值一个合理的数值
            nonzero_ids = np.where(pitch != 0)[0]  # 获取pitch中不为值不为0的索引
            interp_fn = interp1d(
                nonzero_ids,
                pitch[nonzero_ids],
                fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
                bounds_error=False,
            )
            pitch = interp_fn(np.arange(0, len(pitch)))  # 插值后,pitch中为0的部分通过插值得到了补充

            # Phoneme-level average
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    pitch[i] = np.mean(pitch[pos: pos + d])
                else:
                    pitch[i] = 0
                pos += d
            pitch = pitch[: len(duration)]

        if self.energy_phoneme_averaging:
            # Phoneme-level average
            pos = 0
            for i, d in enumerate(duration):
                if d > 0:
                    energy[i] = np.mean(energy[pos: pos + d])
                else:
                    energy[i] = 0
                pos += d
            energy = energy[: len(duration)]

        # Save files
        dur_filename = "{}-duration-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "duration", dur_filename), duration)  # 保存时序时间

        pitch_filename = "{}-pitch-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch)  # 保存pitch

        energy_filename = "{}-energy-{}.npy".format(speaker, basename)
        np.save(os.path.join(self.out_dir, "energy", energy_filename), energy)  # 保存energy

        mel_filename = "{}-mel-{}.npy".format(speaker, basename)
        np.save(
            os.path.join(self.out_dir, "mel", mel_filename),
            mel_spectrogram.T,
        )  # 保存mel谱图

        return (
            "|".join([basename, speaker, text, raw_text]),  # 存储文本形式的数据,字符串
            self.remove_outlier(pitch),  # 去除离群值的pitch序列
            self.remove_outlier(energy),  # 去除离群值的energy序列
            mel_spectrogram.shape[1],  # 记录mel谱图序列帧数
        )

    def get_alignment(self, tier):  # 提取对齐信息
        sil_phones = ["sil", "sp", "spn"]
        # tier中存储的主要内容就是音频的持续时间,以及文中中每个音素对应的持续时间信息
        phones = []  # 音素
        durations = []  # 持续时间
        start_time = 0  # 区间开始时间
        end_time = 0  # 区间结束时间
        end_idx = 0
        for t in tier._objects:  # t的类型是Interval(0.0, 0.04, "P"),第一个开始时间,第二个是结束时间,第三个即为该段对应的文本区间
            s, e, p = t.start_time, t.end_time, t.text

            # Trim leading silences
            if phones == []:
                if p in sil_phones:
                    continue
                else:
                    start_time = s

            if p not in sil_phones:
                # For ordinary phones
                phones.append(p)
                end_time = e
                end_idx = len(phones)  # 记录已记录的音素的个数
            else:
                # For silent phones
                phones.append(p)
            # 记录持续时间,将时间单位秒转换为mel帧数
            durations.append(
                int(
                    np.round(e * self.sampling_rate / self.hop_length)
                    - np.round(s * self.sampling_rate / self.hop_length)
                )
            )

        # Trim tailing silences
        phones = phones[:end_idx]
        durations = durations[:end_idx]

        return phones, durations, start_time, end_time

    def remove_outlier(self, values):  # 删除离群值,使用箱型图的逻辑
        values = np.array(values)
        p25 = np.percentile(values, 25)
        p75 = np.percentile(values, 75)
        lower = p25 - 1.5 * (p75 - p25)
        upper = p75 + 1.5 * (p75 - p25)
        normal_indices = np.logical_and(values > lower, values < upper)

        return values[normal_indices]

    def normalize(self, in_dir, mean, std):
        max_value = np.finfo(np.float64).min
        min_value = np.finfo(np.float64).max
        for filename in os.listdir(in_dir):
            filename = os.path.join(in_dir, filename)
            values = (np.load(filename) - mean) / std
            np.save(filename, values)

            max_value = max(max_value, max(values))
            min_value = min(min_value, min(values))

        return min_value, max_value


if __name__ == '__main__':
    import yaml
    import os
    print(os.getcwd())
    os.chdir('../')
    print(os.getcwd())
    config = yaml.load(open('./config/LJSpeech/preprocess.yaml', "r"), Loader=yaml.FullLoader)
    test = Preprocessor(config)
    test.build_from_path()

本笔记主要记录所选择的FastSpeech复现仓库中数据准备相关的代码,其中主要的步骤是需要从经过MFA工具提去的TextGrid文件中提取音频的文本、duration、pitch和energy信息。TextGrid文件解析使用到tgt和pyworld两个库,也可以结合之前的学习笔记“FastSppech2论文阅读笔记”。本笔记主要是对代码进行详细的注释,读者若发现问题或错误,请评论指出,互相学习。

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值