wav2vec2 pitfalls, part 5: how to build a dataset for transformers
Summary
This post records the main steps of building a transformers dataset, using the THCHS-30 Chinese ASR corpus as the example and mimicking the LibriSpeech layout, so that it can be used for wav2vec2 finetuning. It answers two core questions:
- How do you define a custom dataset for transformers?
- How do you use a local dataset?
The post follows the official add-a-dataset guide and records the pitfalls along the way; I hope it is of some help.
Note:
For privacy, ** stands in for the user name; if ** gets in your way, substitute your own paths. I do not use ~ because the transformers code does not expand it.
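A minimal illustration of why full paths are safer here: ~ only works if the code explicitly calls os.path.expanduser on the path, otherwise the literal ~ survives.

import os

# expanduser resolves ~, plain path handling keeps it as a literal character
print(os.path.expanduser("~/Documents"))  # /home/<user>/Documents
print(os.path.abspath("~/Documents"))     # <current working dir>/~/Documents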
1. Building SLR18 (THCHS-30)
THCHS-30 is a classic Chinese ASR corpus, available for download on OpenSLR.
Step 1.1 Create the project and generate the code
First, how to build your own dataset: download the code as the official guide describes and create the code structure.
# Clone the code
git clone https://github.com/<your Github handle>/datasets
cd datasets
git remote add upstream https://github.com/huggingface/datasets.git
# Add a new dataset package
mkdir ./datasets/slr18
# Add the README
cp ./templates/README.md ./datasets/slr18/README.md
# Generate the dataset script file
cp ./templates/new_dataset_script.py ./datasets/slr18/slr18.py
# Use this dataset
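At this point the official add-a-dataset guide also suggests sanity-checking the new folder with the `datasets-cli test` command (double-check the exact flags against your datasets version); I skip that here and go straight to loading the script locally in Step 1.3.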
Step 1.2 Write the code
After a fair amount of analysis, the finished dataset script looks like this:
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#datasets/slr18/slr18.py
"""
THCHS-30 corpus interface defined by bostenai
Dataset URL: http://openslr.org/18/
Identifier: SLR18
Summary: A Free Chinese Speech Corpus Released by CSLT@Tsinghua University
Category: Speech
License: Apache License v.2.0
@misc{THCHS30_2015,
title={THCHS-30 : A Free Chinese Speech Corpus},
author={Dong Wang, Xuewei Zhang, Zhiyong Zhang},
year={2015},
url={http://arxiv.org/abs/1512.01882}
}
"""
from __future__ import absolute_import, division, print_function
import os
import fnmatch
from multiprocessing import Pool
from functools import partial
import datasets
_CITATION = """\
@InProceedings{huggingface:dataset,
title = {THCHS-30},
author={bostenai, Inc.
},
year={2021}
}
"""
# Taken from http://openslr.org/18/ plus my own understanding
_DESCRIPTION = """\
THCHS30 is an open Chinese speech database published by Center for Speech and Language Technology (CSLT) at Tsinghua University. The original recording was conducted in 2002 by Dong Wang, supervised by Prof. Xiaoyan Zhu, at the Key State Lab of Intelligence and System, Department of Computer Science, Tsinghua University, and the original name was 'TCMSD', standing for 'Tsinghua Continuous Mandarin Speech Database'. The publication after 13 years has been initiated by Dr. Dong Wang and was supported by Prof. Xiaoyan Zhu. We hope to provide a toy database for new researchers in the field of speech recognition. Therefore, the database is totally free to academic users. You can cite the data using the following BibTeX entry:
@misc{THCHS30_2015,
title={THCHS-30 : A Free Chinese Speech Corpus},
author={Dong Wang, Xuewei Zhang, Zhiyong Zhang},
year={2015},
url={http://arxiv.org/abs/1512.01882}
}
This wrapper supports a local data source: set the environment variable SLR18_Corpus to the root directory of the extracted corpus, e.g.
export SLR18_Corpus=/path/to/slr18
The extracted SLR18 directory is expected to look like:
data_thchs30/
    data/
        *.wav
        *.trn
    train/
        *.wav
        *.trn
    dev/
    test/
    lm_phone/
    lm_word/
"""
# Points to OpenSLR; switch to a mirror yourself if you need faster downloads
_HOMEPAGE = "http://openslr.org/18/"
# License copied from THCHS-30
_LICENSE = "Apache License v.2.0"
# Three configurations: `text` returns the Chinese text / the full pinyin / the initial-final-split pinyin
_URLs = {
'thch30': "https://www.openslr.org/resources/18/data_thchs30.tgz",
'pinyin1': "https://www.openslr.org/resources/18/data_thchs30.tgz",
'pinyin2': "https://www.openslr.org/resources/18/data_thchs30.tgz",
}
# Main dataset class
class Slr18Dataset(datasets.GeneratorBasedBuilder):
    """THCHS-30 dataset wrapper built by BostenAI"""

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="thch30", version=VERSION, description="THCHS-30 base data, speech data and transcripts; text is the Chinese transcript"),
        datasets.BuilderConfig(name="pinyin1", version=VERSION, description="THCHS-30 base data; text is the full pinyin, with the tone number at the end"),
        datasets.BuilderConfig(name="pinyin2", version=VERSION, description="THCHS-30 base data; text is the initial/final-split pinyin, with the tone number at the end"),
    ]

    DEFAULT_CONFIG_NAME = "thch30"  # It's not mandatory to have a default configuration. Just use one if it makes sense.

    def _info(self):
        # TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset
        if self.config.name == "thch30":  # This is the name of the configuration selected in BUILDER_CONFIGS above
            features = datasets.Features(
                {
                    "id": datasets.Value("string"),        # example id
                    "file": datasets.Value("string"),      # audio file path
                    "pinyin1": datasets.Value("string"),   # full pinyin
                    "pinyin2": datasets.Value("string"),   # initial/final-split pinyin
                    "mandarin": datasets.Value("string"),  # Chinese transcript
                    "text": datasets.Value("string"),      # Chinese transcript
                }
            )
        elif self.config.name == "pinyin1":
            features = datasets.Features(
                {
                    "id": datasets.Value("string"),        # example id
                    "file": datasets.Value("string"),      # audio file path
                    "pinyin1": datasets.Value("string"),   # full pinyin
                    "pinyin2": datasets.Value("string"),   # initial/final-split pinyin
                    "mandarin": datasets.Value("string"),  # Chinese transcript
                    "text": datasets.Value("string"),      # full pinyin
                }
            )
        elif self.config.name == "pinyin2":
            features = datasets.Features(
                {
                    "id": datasets.Value("string"),        # example id
                    "file": datasets.Value("string"),      # audio file path
                    "pinyin1": datasets.Value("string"),   # full pinyin
                    "pinyin2": datasets.Value("string"),   # initial/final-split pinyin
                    "mandarin": datasets.Value("string"),  # Chinese transcript
                    "text": datasets.Value("string"),      # initial/final-split pinyin
                }
            )
        else:
            raise NotImplementedError("Configuration %s is not implemented!" % (self.config.name))
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )
    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # Check whether a local copy of the corpus was specified
        data_dir = os.getenv('SLR18_Corpus')
        use_local = False
        if data_dir is not None:
            data_dir = os.path.expanduser(data_dir)
            if os.path.exists(data_dir):
                use_local = True
        # Otherwise download and extract the data
        if not use_local:
            my_urls = _URLs[self.config.name]
            data_dir = dl_manager.download_and_extract(my_urls)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, "train"),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, "test"),
                    "split": "test",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "filepath": os.path.join(data_dir, "dev"),
                    "split": "dev",
                },
            ),
        ]
    def _generate_examples(
        self, filepath, split  # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
    ):
        """ Yields examples as (key, example) tuples. """
        # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
        # The `key` is here for legacy reason (tfds) and is not important in itself.
        # Note: the .trn files under train/dev/test in THCHS-30 only point one level up, to the real transcripts.
        assert os.path.exists(filepath), "Dataset split %s does not exist" % (split)
        # Walk the split directory and collect all wav files
        ext = 'wav'
        file_paths = [os.path.join(dirpath, f)
                      for dirpath, dirnames, files in os.walk(filepath)
                      for f in fnmatch.filter(files, "*.%s" % (ext))]
        # Read the transcripts in parallel
        try:
            with Pool() as pool:
                res = pool.map(partial(slr18Corpus, split=split, config=self.config.name), file_paths)
            # Collect the aggregated results
            for i, iR in enumerate(res):
                if iR is None:
                    continue
                yield iR['id'], iR
        except Exception as e:
            raise Exception("Parallel processing failed! %s" % (e))


def slr18Corpus(item, split, config):
    """
    Parse a single wav/trn item.
    """
    tDir = os.path.dirname(item)
    tFname, _ = os.path.splitext(item)
    tTrn = tFname + '.trn'
    try:
        if split in ['train', 'test', 'dev']:
            # The .trn next to the wav only contains a relative path to the real transcript
            with open(tTrn, 'r') as f:
                line = f.readline().strip()
            tTrn = os.path.realpath(os.path.join(tDir, line))
        # Read the real .trn file
        with open(tTrn, 'r') as f:
            md = f.readline().strip()  # Chinese transcript
            p1 = f.readline().strip()  # full pinyin
            p2 = f.readline().strip()  # initial/final-split pinyin
    except Exception:
        print("Failed to load the trn file for %s!" % (item))
        return None
    key = os.path.splitext(os.path.split(item)[-1])[0]  # file name without extension
    example = {
        "id": key,
        "file": item,     # audio file path
        "pinyin1": p1,    # full pinyin
        "pinyin2": p2,    # initial/final-split pinyin
        "mandarin": md,   # Chinese transcript
        'text': md,
    }
    if config == 'pinyin1':
        example['text'] = p1
    elif config == 'pinyin2':
        example['text'] = p2
    return example
Notes:
- Multiprocessing is used for loading; it can briefly max out your machine, but it saves a lot of time (see the debug sketch after these notes).
- Several configurations are used so that text can return different targets (Chinese or pinyin); if you are willing to change the downstream dataset caching in transformers, returning a single field would also work.
- An environment variable enables local loading; if you want the download but forgot to unset the variable, nothing will be downloaded. That is not a bug.
- The wrapper expects an already-extracted local directory rather than a locally downloaded archive, mainly because extraction is hard on the machine.
- There is no audio-quality check or other preprocessing, so the per-item cost is small and the benefit of multiprocessing may not really show.
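If the parallel loading is too aggressive, or you want to step through slr18Corpus in a debugger, one option is to cap the pool size. A minimal sketch, using a hypothetical BST_DEBUG switch like the one set later in this post:

import os
from multiprocessing import Pool

def work(x):
    # stand-in for partial(slr18Corpus, split=..., config=...)
    return x * x

if __name__ == "__main__":
    # One worker when BST_DEBUG=true, otherwise one per CPU core (Pool's default).
    n_proc = 1 if os.getenv("BST_DEBUG", "").lower() == "true" else None
    with Pool(processes=n_proc) as pool:
        print(pool.map(work, range(8)))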
Step 1.3 Test the dataset
My network connection is mediocre, so I will not test the download path; if you run into a bug there, let's fix it together.
Test the local data directly and build the cache:
$ export SLR18_Corpus=/media/**/P02/Data/SLR18/data_thchs30
$ python -c "from datasets import load_dataset;data = load_dataset('/home/**/Documents/workspace/datasets/datasets/slr18/slr18.py','thch30')"
Note: the full path is used instead of the hub shorthand, since this is development testing.
Debug mode: because of the multiprocessing, set the Pool size to 1 while debugging.
"""
python -c '''from datasets import load_dataset;data = load_dataset('/home/**/Documents/workspace/datasets/datasets/slr18/slr18.py','thch30')''
"""
from datasets import load_dataset
import os
os.environ["SLR18_Corpus"] = "/media/**/P02/Data/SLR18/data_thchs30"
os.environ["BST_DEBUG"] = 'true'
data = load_dataset('/home/**/Documents/workspace/datasets/datasets/slr18/slr18.py','thch30')
After a successful run you can find the cache files in a directory like this:
/home/**/.cache/huggingface/datasets/slr18_dataset/thch30/1.1.0/c91677773f127666932cef7ebdb8227a6583b45dbb798cecc9bf3e0ffff61a9c
Note: the audio files are not actually read at this stage, so processing is very fast.
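Once the cache exists, a quick sanity check of what was generated; a minimal sketch assuming the same local path as above:

from datasets import load_dataset

data = load_dataset('/home/**/Documents/workspace/datasets/datasets/slr18/slr18.py', 'thch30')
print(data)                       # split names and example counts
print(data["train"].features)     # id / file / pinyin1 / pinyin2 / mandarin / text
print(data["train"][0]["text"])   # only the transcript string; the wav itself is not decoded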
Step 1.4 wav2vec2 finetune experiment
I put together the following test script:
#!/usr/bin/env bash
python /home/**/Documents/workspace/transformers/examples/research_projects/wav2vec2/run_asr.py \
--output_dir="/home/**/Documents/Projects/transformers/bostenai/od_100h-zh_CN" \
--num_train_epochs="30" \
--per_device_train_batch_size="10" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="/home/**/Documents/Projects/transformers/bostenai/100h-zh_CN" \
--fp16 \
--dataset_name="/home/**/Documents/workspace/datasets/datasets/slr18/slr18.py" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="librispeech" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging
Notes:
- evaluation_strategy: choose between steps and epoch according to your needs and the dataset size; set it carefully or you will waste a lot of time (see the step-count sketch after these notes).
- model_name_or_path is the local model cache directory; see the earlier posts for details.
- dataset_name is the custom dataset we defined above; if you want the local data to be used, remember to set the environment variable.
- orthography mimics librispeech and uses the text field.
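To pick sensible --eval_steps/--save_steps it helps to estimate how many optimizer steps one epoch takes. A minimal sketch; the sample count below is a placeholder, use whatever len(data["train"]) reports for your split:

import math

num_train_samples = 10000            # placeholder, e.g. len(data["train"])
per_device_train_batch_size = 10     # matches the script above
n_devices = 1
num_train_epochs = 30

steps_per_epoch = math.ceil(num_train_samples / (per_device_train_batch_size * n_devices))
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # use these when choosing --eval_steps / --save_steps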
Error 1: missing wer metric cache
The first run fails because the wer metric cannot be fetched, like this:
Exception "ConnectionError"
Couldn't reach https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/wer/wer.py
File: /home/**/.conda/envs/lSrv39/lib/python3.9/site-packages/datasets/utils/file_utils.py, Line: 616
>>> url
'https://raw.githubusercontent.com/huggingface/datasets/1.5.0/metrics/wer/wer.py'
If you have internet access you will not hit this. If you unfortunately do not, hack line 364 of run_asr.py: copy the metric source into the installed package directory and force the load path.
#wer_metric = datasets.load_metric("wer")
wer_metric = datasets.load_metric("/home/**/.conda/envs/lSrv39/lib/python3.9/site-packages/datasets/metrics/wer/wer.py")
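A quick way to confirm the locally loaded metric works; a minimal sketch assuming the same path as in the patch above:

from datasets import load_metric

wer_metric = load_metric("/home/**/.conda/envs/lSrv39/lib/python3.9/site-packages/datasets/metrics/wer/wer.py")
# Identical prediction and reference should give a WER of 0.0
print(wer_metric.compute(predictions=["ni hao shi jie"], references=["ni hao shi jie"]))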
Error 2: missing library jiwer
Exception "ImportError"
To be able to use this metric, you need to install the following dependencies['jiwer'] using 'pip install jiwer' for instance'
File: /home/**/.conda/envs/lSrv39/lib/python3.9/site-packages/datasets/load.py, Line: 447
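The fix is exactly what the message says: install jiwer with pip in the same environment, then rerun.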
Error 3: crash caused by a division by zero
# site-packages/transformers/trainer_utils.py
def speed_metrics(split, start_time, num_samples=None):
    """
    Measure and return speed performance metrics.
    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
    should be run immediately after the operation to be measured has completed.
    Args:
    - split: name to prefix metric (like train, eval, test...)
    - start_time: operation start time
    - num_samples: number of samples processed
    """
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        # Guard against division by zero
        if num_samples != 0:
            samples_per_second = 1 / (runtime / num_samples)
        else:
            samples_per_second = 1
        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
    return result
Error 4: running out of disk
Add a checkpoint limit and change the saving strategy so the disk does not fill up:
--save_total_limit=3 --save_steps=3000 \
Error 5: resuming training
To continue training you can set overwrite_output_dir, but in practice it did not seem to have any effect; to be continued.
run_asr.py: error: argument --overwrite_output_dir: Truthy value expected: got =/home/**/Documents/Projects/Fairseq/ModelSerial/t1 but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive).
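Judging from the parser message, the flag expects an explicit boolean, so passing it as --overwrite_output_dir=true (as in the script below) avoids the parse error; whether it then actually resumes training properly is a separate question I have not settled.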
Step 1.5 Testing with a reduced dataset
If the files in your private dataset are simply too large, debugging wastes a lot of time in the preprocessing stage. For that reason I set up a small dataset, a simple configuration, to finetune against. You may have the same annoyance; the idea is to stitch the corpus together with symlinks.
#ln -s source/corpus/dir $(cName)_$(split)
sDir=/home/**/Documents/Data/ # where the original data lives
dDir=/home/**/Documents/Projects/Fairseq/ModelSerial/t1/Data # where the project data lives
# BostenAI data: /home/**/Documents/Data/BostenAI/BSTPlan-patch1
#train
ln -s $sDir/xxx/0016 $dDir/train/xx/0016
#test
ln -s $sDir/xxx/test $dDir/test
#dev
ln -s $sDir/xxx/0064 $dDir/dev/xx/0064
ln -s $sDir/xxx0152 $dDir/dev/xxx/0152
Create the new dataset mapping with:
$ python -c "from datasets import load_dataset;data = load_dataset('/home/**/Documents/workspace/datasets/datasets/xxx/xxx.py', 'simple')"
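To confirm the symlinked layout was picked up correctly, something like the following helps; a sketch assuming the private script exposes the same file field as slr18:

from datasets import load_dataset

data = load_dataset('/home/**/Documents/workspace/datasets/datasets/xxx/xxx.py', 'simple')
for split in data:
    print(split, len(data[split]), data[split][0]["file"])  # size and one resolved path per split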
Then run the finetune on top of it:
#!/usr/bin/env bash
export BST_Corpus=/home/**/Documents/Data/BostenAI
export BST_Corpus_Simple=/home/**/Documents/Data/BostenAI
# The environment variables point at the data sources
# If you need to run on CPU, add --no_cuda and remove --fp16
python /home/**/Documents/workspace/transformers/examples/research_projects/wav2vec2/run_asr.py \
--output_dir=/home/**/Documents/Projects/Fairseq/ModelSerial/t1 --overwrite_output_dir=true \
--num_train_epochs=30 \
--per_device_train_batch_size=1 \
--per_device_eval_batch_size=2 \
--evaluation_strategy=epoch \
--save_total_limit=3 --save_steps=3000 \
--logging_steps=50 \
--learning_rate=5e-4 \
--warmup_steps=3000 \
--model_name_or_path=/home/**/Documents/Projects/transformers/bostenai/100h-zh_CN \
--dataset_name=/home/**/Documents/workspace/datasets/datasets/bostenai_asr/bostenai_asr.py \
--dataset_config_name=simple \
--train_split_name=train \
--validation_split_name=test \
--orthography=librispeech \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging --no_cuda --fp16
# If your test set is fairly large, you can use --evaluation_strategy=steps and set --save_steps=1000 or larger
# You can also use --evaluation_strategy=epoch to evaluate after every epoch
# dataset_config_name selects the sub-configuration of your dataset, here one with a smaller train and test set
Tips:
- Because a local script path is used, every edit to bostenai_asr.py invalidates the dataset cache; unless something is really wrong, avoid reloading the dataset over and over (see the snapshot sketch after these tips).
- The scripts above are for my own private dataset and are only an example; slr18 does not ship a simple configuration.
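If re-editing the dataset script keeps invalidating the cache, one workaround is to snapshot the built dataset once and reload the snapshot instead of the script; a sketch using the standard save_to_disk/load_from_disk APIs, with a made-up snapshot path:

from datasets import load_dataset, load_from_disk

data = load_dataset('/home/**/Documents/workspace/datasets/datasets/xxx/xxx.py', 'simple')
data.save_to_disk('/home/**/Documents/Data/simple_snapshot')   # hypothetical location

# later runs, unaffected by further edits to the dataset script:
data = load_from_disk('/home/**/Documents/Data/simple_snapshot')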
Step 1.6 Upload
I will not go into the upload script and its tests, because....
I hope the above is useful to you.