WaveNet 代码解析 —— train.py
文章目录
简介
本项目是一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址)
代码解析
全局变量解析
以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。
BATCH_SIZE = 1 # 一批训练集中,样本音频的数量
DATA_DIRECTORY = './VCTK-Corpus' # 下载的VCTK数据集的路径
LOGDIR_ROOT = './logdir' # 训练日志的路径
CHECKPOINT_EVERY = 50 # 保存训练模型的检查点数量
NUM_STEPS = int(1e5) # 训练的总次数
LEARNING_RATE = 1e-3 # 学习率
WAVENET_PARAMS = './wavenet_params.json' # WaveNet 模型的相关参数路径
STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) # 当前日期格式化
SAMPLE_SIZE = 100000 # 样本数量大小
L2_REGULARIZATION_STRENGTH = 0 # L2正则化中的系数
SILENCE_THRESHOLD = 0.3 # 音量阈值大小
EPSILON = 0.001 # 精度设置
MOMENTUM = 0.9 # 优化器动量
MAX_TO_KEEP = 5 # 保存的最大检查点数量
METADATA = False # 高级调试信息存储标志
函数解析
main()
下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。
def main():
# 解析命令行功能参数
args = get_arguments()
try:
# 验证并整理与目录有关的参数
directories = validate_directories(args)
except ValueError as e:
print("Some arguments are wrong:")
print(str(e))
return
# 将整理好的文件路径赋给相应变量
logdir = directories['logdir']
restore_from = directories['restore_from']
# 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练
is_overwritten_training = logdir != restore_from
# 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dict
with open(args.wavenet_params, 'r') as f:
wavenet_params = json.load(f)
# 创建线程协调器,多线程协调器相关知识可参考文章地址如下:
# https://blog.csdn.net/weixin_42721167/article/details/112795491
coord = tf.train.Coordinator()
# 从VCTK数据集中加载原始波形
with tf.name_scope('create_inputs'):
# 允许通过指定接近零的阈值跳过静默修剪
silence_threshold = args.silence_threshold if args.silence_threshold > \
EPSILON else None
gc_enabled = args.gc_channels is not None
# 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中
reader = AudioReader(
args.data_dir,
coord,
sample_rate=wavenet_params['sample_rate'],
gc_enabled=gc_enabled,
receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
wavenet_params["dilations"],
wavenet_params["scalar_input"],
wavenet_params["initial_filter_width"]),
sample_size=args.sample_size,
silence_threshold=silence_threshold)
# 准备好的音频出队列
audio_batch = reader.dequeue(args.batch_size)
if gc_enabled:
gc_id_batch = reader.dequeue_gc(args.batch_size)
else:
gc_id_batch = None
# 创建 WaveNet 网络
net = WaveNetModel(
batch_size=args.batch_size,
dilations=wavenet_params["dilations"],
filter_width=wavenet_params["filter_width"],
residual_channels=wavenet_params["residual_channels"],
dilation_channels=wavenet_params["dilation_channels"],
skip_channels=wavenet_params["skip_channels"],
quantization_channels=wavenet_params["quantization_channels"],
use_biases=wavenet_params["use_biases"],
scalar_input=wavenet_params["scalar_input"],
initial_filter_width=wavenet_params["initial_filter_width"],
histograms=args.histograms,
global_condition_channels=args.gc_channels,
global_condition_cardinality=reader.gc_category_cardinality)
# 验证 l2 正则化系数
if args.l2_regularization_strength == 0:
args.l2_regularization_strength = None
# 创建一个 WaveNet 网络并返回自动编码损耗
loss = net.loss(input_batch=audio_batch,
global_condition_batch=gc_id_batch,
l2_regularization_strength=args.l2_regularization_strength)
# 创建对应的优化器
optimizer = optimizer_factory[args.optimizer](
learning_rate=args.learning_rate,
momentum=args.momentum)
# 返回使用 trainable=True 创建的所有变量
trainable = tf.trainable_variables()
optim = optimizer.minimize(loss, var_list=trainable)
# 设置TensorBoard的日志记录
writer = tf.summary.FileWriter(logdir)
writer.add_graph(tf.get_default_graph())
# 收集关于训练的元信息
run_metadata = tf.RunMetadata()
summaries = tf.summary.merge_all()
# 建立会话
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)
# 存储模型检查点的保护程序
# 在创建这个 Saver 对象的时候, max_to_keep 参数表示要保留的最近检查点文件的最大数量,创建新文件时,将删除旧文件
saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)
try:
# 恢复训练模型,获取训练步数
saved_global_step = load(saver, sess, restore_from)
if is_overwritten_training or saved_global_step is None:
# 第一个训练步骤将是 saved_global_step + 1,因此我们在这里输入-1表示新的或覆盖的训练
saved_global_step = -1
except:
print("Something went wrong while restoring checkpoint. "
"We will terminate training to avoid accidentally overwriting "
"the previous model.")
raise
# 开启入队线程启动器,详细介绍可参考这篇博客:
# https://blog.csdn.net/weixin_42721167/article/details/112795491
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
reader.start_threads(sess)
step = None
last_saved_step = saved_global_step
try:
# 从恢复模型的节点处开始训练
for step in range(saved_global_step + 1, args.num_steps):
# 获取当前时间
start_time = time.time()
# 当存储标志为 true 且训练次数为50的倍数时存储调试信息
if args.store_metadata and step % 50 == 0:
# 缓慢运行,存储额外的调试信息
print('Storing metadata')
# RunOptions提供配置参数,供SessionRun调用时使用
run_options = tf.RunOptions(
trace_level=tf.RunOptions.FULL_TRACE)
# 计算日志与自动编码的损失
summary, loss_value, _ = sess.run(
[summaries, loss, optim],
options=run_options,
run_metadata=run_metadata)
# 调用train_writer的add_summary方法将训练过程以及训练步数保存
writer.add_summary(summary, step)
# 记录CPU/内存使用情况
writer.add_run_metadata(run_metadata,
'step_{:04d}'.format(step))
# Tensorflow的Timeline模块是用于描述张量图一个工具,可以记录在会话中每个操作执行时间和资源分配及消耗的情况
tl = timeline.Timeline(run_metadata.step_stats)
# 加载文件路径,打开文件,写入日志
timeline_path = os.path.join(logdir, 'timeline.trace')
with open(timeline_path, 'w') as f:
f.write(tl.generate_chrome_trace_format(show_memory=True))
else:
# 在不保存模型的训练步数里,保存训练日志到 Tensorboard
summary, loss_value, _ = sess.run([summaries, loss, optim])
writer.add_summary(summary, step)
# 计算并打印训练一次的时间与结果
duration = time.time() - start_time
print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
.format(step, loss_value, duration))
# 每隔输入的检查点间隔存储一次训练模型
if step % args.checkpoint_every == 0:
save(saver, sess, logdir, step)
last_saved_step = step
except KeyboardInterrupt:
# 在 ctrl+C 显示之后引入一个换行符,这样保存消息就在它自己的行上了
print()
finally:
# 若训练到了更多步
if step > last_saved_step:
save(saver, sess, logdir, step)
coord.request_stop()
coord.join(threads)
get_arguments()
下面这段代码主要是获取命令行参数。
运用 python 中的 argparse 模块对我们输入的命令行进行解析。
def get_arguments():
def _str_to_bool(s):
""" 将string转换为bool """
""" 传入的字符串被限制为'true'或'false' """
if s.lower() not in ['true', 'false']:
raise ValueError('Argument needs to be a '
'boolean, got {}'.format(s))
return {'true': True, 'false': False}[s.lower()]
# 创建解析器,解析的功能参数作为 WaveNet 的实例
parser = argparse.ArgumentParser(description='WaveNet example network')
# 添加可选功能参数: --batch_size; 该参数含义为: 一次要处理的 wav 文件数量
parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')
# 添加可选功能参数: --data_dir; 该参数含义为: VCTK数据集的文件路径
parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
help='The directory containing the VCTK corpus.')
# 添加可选功能参数: --store_metadata; 该参数含义为: 高级调试信息存储标志
parser.add_argument('--store_metadata', type=bool, default=METADATA,
help='Whether to store advanced debugging information '
'(execution time, memory consumption) for use with '
'TensorBoard. Default: ' + str(METADATA) + '.')
# 添加可选功能参数: --logdir; 该参数含义为: 存储 TensorBoard 日志信息的文件路径;
# 需要注意: 该参数不能与'--logdir_root'或'--restore_from'一起使用
parser.add_argument('--logdir', type=str, default=None,
help='Directory in which to store the logging '
'information for TensorBoard. '
'If the model already exists, it will restore '
'the state and will continue training. '
'Cannot use with --logdir_root and --restore_from.')
# 添加可选功能参数: --logdir_root; 该参数含义为: 放置日志输出和生成模型的文件路径,存放在带有日期的子目录下
# 需要注意: 该参数不能与'--logdir'一起使用
parser.add_argument('--logdir_root', type=str, default=None,
help='Root directory to place the logging '
'output and generated model. These are stored '
'under the dated subdirectory of --logdir_root. '
'Cannot use with --logdir.')
# 添加可选功能参数: --restore_from; 该参数含义为: 恢复模型的目录,能创建带有日期的子目录
# 需要注意: 该参数不能与'--logdir'一起使用
parser.add_argument('--restore_from', type=str, default=None,
help='Directory in which to restore the model from. '
'This creates the new model under the dated directory '
'in --logdir_root. '
'Cannot use with --logdir.')
# 添加可选功能参数: --checkpoint_every; 该参数含义为: 存放训练模型的检查点间隔
parser.add_argument('--checkpoint_every', type=int,
default=CHECKPOINT_EVERY,
help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')
# 添加可选功能参数: --num_steps; 该参数含义为: 训练的次数
parser.add_argument('--num_steps', type=int, default=NUM_STEPS,
help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')
# 添加可选功能参数: --learning_rate; 该参数含义为: 训练的学习率
parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')
# 添加可选功能参数: --wavenet_params; 该参数含义为: WaveNet 模型的相关参数
parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,
help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')
# 添加可选功能参数: --sample_size; 该参数含义为: 使用的样本数量
parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,
help='Concatenate and cut audio samples to this many '
'samples. Default: ' + str(SAMPLE_SIZE) + '.')
# 添加可选功能参数: --l2_regularization_strength; 该参数含义为: L2正则化的系数
parser.add_argument('--l2_regularization_strength', type=float,
default=L2_REGULARIZATION_STRENGTH,
help='Coefficient in the L2 regularization. '
'Default: False')
# 添加可选功能参数: --silence_threshold; 该参数含义为: 音量阈值限制
parser.add_argument('--silence_threshold', type=float,
default=SILENCE_THRESHOLD,
help='Volume threshold below which to trim the start '
'and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')
# 添加可选功能参数: --optimizer; 该参数含义为: 优化器选择
parser.add_argument('--optimizer', type=str, default='adam',
choices=optimizer_factory.keys(),
help='Select the optimizer specified by this option. Default: adam.')
# 添加可选功能参数: --momentum; 该参数含义为: 优化器动量大小
parser.add_argument('--momentum', type=float,
default=MOMENTUM, help='Specify the momentum to be '
'used by sgd or rmsprop optimizer. Ignored by the '
'adam optimizer. Default: ' + str(MOMENTUM) + '.')
# 添加可选功能参数: --histograms; 该参数含义为: 直方图汇总存储标志
parser.add_argument('--histograms', type=_str_to_bool, default=False,
help='Whether to store histogram summaries. Default: False')
# 添加可选功能参数: --gc_channels; 该参数含义为: 全局条件通道数量
parser.add_argument('--gc_channels', type=int, default=None,
help='Number of global condition channels. Default: None. Expecting: Int')
# 添加可选功能参数: --max_checkpoints; 该参数含义为: 最大训练模型保存检查点数
parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,
help='Maximum amount of checkpoints that will be kept alive. Default: '
+ str(MAX_TO_KEEP) + '.')
# 把parser中设置的所有"add_argument"给返回到args子类实例中并返回
return parser.parse_args()
validate_directories(args)
下面这段代码主要工作是:验证当前的几个目录是否冲突,将输入的目录参数规范化。
def validate_directories(args):
""" 验证和整理与目录相关的参数 """
# 验证接断
# logdir 与 logdir_root 参数不能同时存在
if args.logdir and args.logdir_root:
raise ValueError("--logdir and --logdir_root cannot be "
"specified at the same time.")
# logdir 与 restore_from 参数不能同时存在
if args.logdir and args.restore_from:
raise ValueError(
"--logdir and --restore_from cannot be specified at the same "
"time. This is to keep your previous model from unexpected "
"overwrites.\n"
"Use --logdir_root to specify the root of the directory which "
"will be automatically created with current date and time, or use "
"only --logdir to just continue the training from the last "
"checkpoint.")
# 整理阶段
# 为 logdir_root 参数赋予给定的值或是默认值
logdir_root = args.logdir_root
if logdir_root is None:
logdir_root = LOGDIR_ROOT
# 为 logdir 参数赋予给定的值或是 logdir_root 参数的默认值
logdir = args.logdir
if logdir is None:
logdir = get_default_logdir(logdir_root)
print('Using default logdir: {}'.format(logdir))
# 为 restore_from 参数赋予给定的值或是 logdir 参数的值
restore_from = args.restore_from
if restore_from is None:
# args.logdir and args.restore_from are exclusive,
# so it is guaranteed the logdir here is newly created.
restore_from = logdir
# 将验证并整理好的目录参数打包返回
return {
'logdir': logdir,
'logdir_root': args.logdir_root,
'restore_from': restore_from
}
get_default_logdir(logdir_root)
下面这段代码主要工作是:在给定的日志目录下,创建训练文件夹,再创建以带有当前日期时间的文件路径,并将该路径返回
def get_default_logdir(logdir_root):
# 使用路径拼接函数 os.path.join() 在给定的目录下创建'train'目录
# 进而创建以当前日期时间为名的子目录,格式为:{0:%Y-%m-%dT%H-%M-%S}
logdir = os.path.join(logdir_root, 'train', STARTED_DATESTRING)
return logdir
save(saver, sess, logdir, step)
这段代码主要工作是:将给定的训练结果、会话以及检查点保存到指定的文件路径下
def save(saver, sess, logdir, step):
# 设置保存的模型文件名,将文件路径进行拼接
model_name = 'model.ckpt'
checkpoint_path = os.path.join(logdir, model_name)
print('Storing checkpoint to {} ...'.format(logdir), end="")
# 刷新缓冲区,保证正常输出
sys.stdout.flush()
# 若文件不存在则先创造文件
if not os.path.exists(logdir):
os.makedirs(logdir)
# 保存模型
saver.save(sess, checkpoint_path, global_step=step)
print(' Done.')
load(saver, sess, logdir)
这段代码主要工作是:将指定路径下的模型训练结果恢复到当前会话
def load(saver, sess, logdir):
print("Trying to restore saved checkpoints from {} ...".format(logdir),
end="")
# 从指定路径下返回训练模型以及检查点
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt:
print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))
# 找到模型,获取检查点
global_step = int(ckpt.model_checkpoint_path
.split('/')[-1]
.split('-')[-1])
print(" Global step was: {}".format(global_step))
print(" Restoring...", end="")
# 恢复最新检查点训练情况
saver.restore(sess, ckpt.model_checkpoint_path)
print(" Done.")
# 返回检查点
return global_step
else:
# 未找到模型,返回空值
print(" No checkpoint found.")
return None
本文还在持续更新中!
欢迎各位大佬交流讨论!