模型训练
有了处理好的数据,我们就可以进行训练了。你可以选择本地训练或在OpenPAI上训练。
OpenPAI上训练
OpenPAI 作为开源平台,提供了完整的 AI 模型训练和资源管理能力,能轻松扩展,并支持各种规模的私有部署、云和混合环境。因此,我们推荐在OpenPAI上训练。
完整训练过程请查阅: 在OpenPAI上训练
本地训练
如果你的本地机器性能较好,也可以在本地训练。
模型训练的代码请参考 train.sh。
训练过程依然调用t2t模型训练命令:。具体命令如下:t2t_trainer
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./output
LOG_DIR=${TRAIN_DIR}
DATA_DIR=./data_dir
USR_DIR=./usr_dir
PROBLEM=translate_up2down
MODEL=transformer
HPARAMS_SET=transformer_small
t2t-trainer \
--t2t_usr_dir=${USR_DIR} \
--data_dir=${DATA_DIR} \
--problem=${PROBLEM} \
--model=${MODEL} \
--hparams_set=${HPARAMS_SET} \
--output_dir=${TRAIN_DIR} \
--keep_checkpoint_max=1000 \
--worker_gpu=1 \
--train_steps=200000 \
--save_checkpoints_secs=1800 \
--schedule=train \
--worker_gpu_memory_fraction=0.95 \
--hparams="batch_size=1024" 2>&1 | tee -a ${LOG_DIR}/train_default.log
</code></span></span></span></span>
各项参数的作用和取值分别如下:
-
t2t_usr_dir
:如前一小节所述,指定了处理对联问题的模块所在的目录。 -
data_dir
:训练数据目录 -
problem
:问题名称,即translate_up2down -
model
:训练所使用的 NLP 算法模型,本案例中使用 transformer 模型 -
hparams_set
:transformer 模型下,具体使用的模型。transformer 的各种模型定义在 tensor2tensor/models/transformer.py 文件夹内。本案例使用 transformer_small 模型。 -
output_dir
:保存训练结果 -
keep_checkpoint_max
:保存 checkpoint 文件的最大数目 -
worker_gpu
:是否使用 GPU,以及使用多少 GPU 资源 -
train_steps
:总训练次数 -
save_checkpoints_secs
:保存 checkpoint 的时间间隔 -
schedule
:将要执行的 方法,比如:train, train_and_evaluate, continuous_train_and_eval,train_eval_and_decode, run_std_servertf.contrib.learn.Expeiment
-
worker_gpu_memory_fraction
:分配的 GPU 显存空间 -
hparams
:定义 batch_size 参数。
好啦,我们输入完命令,点击回车,训练终于跑起来啦!如果你在拥有一块 K80 显卡的机器上运行,只需5个小时就可以完成训练。如果你只有 CPU ,那么你只能多等几天啦。 我们将训练过程运行在 Microsoft OpenPAI 分布式资源调度平台上,使用一块 K80 进行训练。
如果你想利用OpenPAI平台训练,可以查看在OpenPAI上训练。
4小时24分钟后,训练完成,得到如下模型文件:
- 检查站
- 型号.ckpt-200000.data-00000-of-00003
- 型号.ckpt-200000.data-00001-of-00003
- 型号.ckpt-200000.data-00002-of-00003
- 型号.ckpt-200000.index
- 型号.ckpt-200000.meta
我们将使用该模型文件进行模型推理。
模型推理
在这一阶段,我们将使用上述训练得到的模型文件进行模型推理,利用上联生成下联。
新建推理脚本文件inference.sh
点击查看 inference.sh 的代码。
在推理之前,需要注意如下几个目录:
TRAIN_DIR
:上述的训练模型文件存放的目录。DATA_DIR
:训练字典文件存放目录,即之前提到的。merge.txt.vocab.clean
USR_DIR
:自定义问题的存放目录,即之前提到的文件。merge_vocab.py
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./output
DATA_DIR=./data_dir
USR_DIR=./usr_dir
DECODE_FILE=./decode_this.txt
PROBLEM=translate_up2down
MODEL=transformer
HPARAMS=transformer_small
BEAM_SIZE=4
ALPHA=0.6
poet=$1
new_chars=""
for ((i=0;i < ${#poet} ;++i))
do
new_chars="$new_chars ${poet:i:1}"
done
echo $new_chars > decode_this.txt
echo "生成中..."
t2t-decoder \
--