AI对联生成案例(二)

模型训练

有了处理好的数据,我们就可以进行训练了。你可以选择本地训练或在OpenPAI上训练

OpenPAI上训练

OpenPAI 作为开源平台,提供了完整的 AI 模型训练和资源管理能力,能轻松扩展,并支持各种规模的私有部署、云和混合环境。因此,我们推荐在OpenPAI上训练。

完整训练过程请查阅: 在OpenPAI上训练

本地训练

如果你的本地机器性能较好,也可以在本地训练。

模型训练的代码请参考 train.sh

训练过程依然调用t2t模型训练命令:。具体命令如下:t2t_trainer

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./output
LOG_DIR=${TRAIN_DIR}
DATA_DIR=./data_dir
USR_DIR=./usr_dir

PROBLEM=translate_up2down
MODEL=transformer
HPARAMS_SET=transformer_small

t2t-trainer \
--t2t_usr_dir=${USR_DIR} \
--data_dir=${DATA_DIR} \
--problem=${PROBLEM} \
--model=${MODEL} \
--hparams_set=${HPARAMS_SET} \
--output_dir=${TRAIN_DIR} \
--keep_checkpoint_max=1000 \
--worker_gpu=1 \
--train_steps=200000 \
--save_checkpoints_secs=1800 \
--schedule=train \
--worker_gpu_memory_fraction=0.95 \
--hparams="batch_size=1024" 2>&1 | tee -a ${LOG_DIR}/train_default.log
</code></span></span></span></span>

各项参数的作用和取值分别如下:

  1. t2t_usr_dir:如前一小节所述,指定了处理对联问题的模块所在的目录。

  2. data_dir:训练数据目录

  3. problem:问题名称,即translate_up2down

  4. model:训练所使用的 NLP 算法模型,本案例中使用 transformer 模型

  5. hparams_set:transformer 模型下,具体使用的模型。transformer 的各种模型定义在 tensor2tensor/models/transformer.py 文件夹内。本案例使用 transformer_small 模型。

  6. output_dir:保存训练结果

  7. keep_checkpoint_max:保存 checkpoint 文件的最大数目

  8. worker_gpu:是否使用 GPU,以及使用多少 GPU 资源

  9. train_steps:总训练次数

  10. save_checkpoints_secs:保存 checkpoint 的时间间隔

  11. schedule:将要执行的 方法,比如:train, train_and_evaluate, continuous_train_and_eval,train_eval_and_decode, run_std_servertf.contrib.learn.Expeiment

  12. worker_gpu_memory_fraction:分配的 GPU 显存空间

  13. hparams:定义 batch_size 参数。

好啦,我们输入完命令,点击回车,训练终于跑起来啦!如果你在拥有一块 K80 显卡的机器上运行,只需5个小时就可以完成训练。如果你只有 CPU ,那么你只能多等几天啦。 我们将训练过程运行在 Microsoft OpenPAI 分布式资源调度平台上,使用一块 K80 进行训练。

如果你想利用OpenPAI平台训练,可以查看在OpenPAI上训练

4小时24分钟后,训练完成,得到如下模型文件:

  • 检查站
  • 型号.ckpt-200000.data-00000-of-00003
  • 型号.ckpt-200000.data-00001-of-00003
  • 型号.ckpt-200000.data-00002-of-00003
  • 型号.ckpt-200000.index
  • 型号.ckpt-200000.meta

我们将使用该模型文件进行模型推理。

模型推理

在这一阶段,我们将使用上述训练得到的模型文件进行模型推理,利用上联生成下联。

新建推理脚本文件inference.sh

点击查看 inference.sh 的代码。

在推理之前,需要注意如下几个目录:

  • TRAIN_DIR:上述的训练模型文件存放的目录。
  • DATA_DIR:训练字典文件存放目录,即之前提到的。merge.txt.vocab.clean
  • USR_DIR:自定义问题的存放目录,即之前提到的文件。merge_vocab.py
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>TRAIN_DIR=./output
DATA_DIR=./data_dir
USR_DIR=./usr_dir

DECODE_FILE=./decode_this.txt

PROBLEM=translate_up2down
MODEL=transformer
HPARAMS=transformer_small

BEAM_SIZE=4
ALPHA=0.6

poet=$1
new_chars=""
for ((i=0;i < ${#poet} ;++i))
do
new_chars="$new_chars ${poet:i:1}"
done

echo $new_chars > decode_this.txt

echo "生成中..."

t2t-decoder \
--
  • 41
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值