Tensorflow nmt的超参数

Tensorflow nmt的超参数  

超参数一般用来定义我们的神经网络的关键参数.  

在tensorflow/nmt这个demo中,我们的超参数在 nmt.nmt 模块中配置.这也导致了nmt.py这个文件的代码行数比较多,我们完全可以把参数的配置放到单独的一个文件中去.nmt.py 这个文件也是整个项目的入口文件.如果你想了解这个demo的整体结构,请查看我的另一篇博客tensorflow/nmt的整体结构, 这就不展开了. 

下面我会列出nmt模型定义的超参数,并且追条解释,希望能加深你对这些参数的理解.  

本demo的超参数使用的是argparse模块进行配置的,如果你喜欢,也可以使用tensorflow中的 tf.app.flags.DEFINE_xxx() 函数来配置,后者是前者的简单封装.  

超参数列表  

首先用表格的形式列出所有的超参数,对他们的解释放在下一小节.  

超参数(hparams)类型(type)默认值(default)简介(help)
--num_unitsint32network size
--num_layersint2network depth
--num_encoder_layersintNoneencoder depth, equal to num_layers if None
--num_decoder_layersinyNonedecoder depth, equal to num_layers if None
--encoder_typestrunione of uni, bi, gnmt
--residualboolFalsewhether to add residual connections
--time_majorboolTruewhether to add time-major mode for dynamic RNN
--num_embeddings_partitionsint0number of partitions for embedding vars
--attentionstr""one of "", luong, scaled_luong, bahdanau, normed_bahdanau
--attention_architecturestrstandardone of standard, gnmt, gnmt_v2
--output_attentionboolTrueonly used in standard attention_architecture
--pass_hidden_stateboolTruewhether to pass encoder’s hidden state to decoder
--optimizerstrsgdone of sgd, adam
--learning_ratefloat1.0adam: 0.001 or 0.0001
--warmup_stepsint0how many steps we inverse-decay learning
--warmup_schemestrt2thow to warmup learning rates
--decay_schemestr""how we decay learning rate
--num_train_stepsint12000num steps to train
--colocate_gradients_with_opsboolTruewhether try colocating gradients with corresponding op
--init_opstruniformone of uniform, glorot_normal, glorot_uniform
--init_weightfloat0.1for uniform init_op, initialize weights
--srcstrNonesource suffix
--tgtstrNonetarget suffix
--train_prefixstrNonetrain prefix
--dev_prefixstrNonedev prefix
--test_prefixstrNonetest prefix
--out_dirstrNonemodel folder
--vocab_prefixstrNonevocab prefix
--emded_prefixstrNonePretrained embedding prefix, should be Glove formated txt files
--sosstr<s>Start-of-sentence symbol
--eosstr</s>End-of-sentence symbol
--share_vocabstrFalsewhether use the same vocab between source and target
--check_special_tokenboolTruewhether check special sos, eos, unk tokens exist in the vocab files
--src_max_lenint50max length of source sequence during training
--tgt_max_lenint50max length of target sequence during training
--src_max_len_inferintNonemax length of source sequence during inference
--tgt_max_len_inferintNonemax length of target sequence during inference
--unit_typestrlstmone of lstm, gru, layer_norm_lstm, nas
--forget_biasfloat1.0forget bias for BasicLSTMCell
--dropoutfloat0.2dropout rate
--max_gradient_normfloat5.0clip gradients to this norm
--batch_sizeint128batch size
--steps_per_statsint100how many training steps to do per stats logging
--max_trainint0limit on the size of training data(0: no limit)
--num_bucketsint5put data into similar-length buckets
--subword_optionstr""one of "", bpe, spm
--num_gpusint1number of gpus in each worker
--log_device_placementboolFalsedebug gpu allocation
--metricsstrbleucomma-separated list of evaluations
--steps_per_external_evalintNonehow many training steps to do per external evaluation
--scopestrNonescope to put variables under
--hparams_pathstrNonepath to hparams json file
--random_seedintNonerandom seed
--override_loadded_hparamsboolFlaseoverride loaded hparams with values specified
--num_keep_ckptsint5max number of checkpoints to keep
--avg_ckptsboolFalseaverage the last N checkpoints for external evaluation
--ckptstr""checkpoint file to load a model for inference
--inference_input_filestrNoneset to the text decode
--inference_liststrNonea comma-separated list of sentence indices
--infer_batch_sizeint32batch size for inference mode
--inference_ouput_filestrNoneoutput file to store decoding results
--inference_ref_filestrNonereference file to compute evaluation scores
--beam_widthint0beam width when using beam search decoder
--length_penalty_weightfloat0.0length penalty for beam search
--sampling_temperaturefloat0.0softmax sampling temperature for inference decoding
--num_translations_per_inputint1number of translations generated for each sentence
--jobidint0task if of the worker
--num_workersint1number of workers(inference only)
--num_inter_threadsint0number of inter_op_parallelism_threads
--num_train_threadsint0number of intra_op_parallelism_threads

逐条详解  

上一小节列出了所有的超参数,接下来我将分组进行更加详细的解释。  

数据相关参数  

本小节介绍数据相关的参数:  
* --src
该参数指定训练数据中,源数据的文件后缀名。举个例子,我们的训练数据是一对逐行一一对应的文本文件,分别为address_train.ocraddress_train.std,那么此时我们需要指定该参数为: --src=ocr  
* --tgt
该参数指定训练数据中,目标数据的文件后缀名,按照上面的举例,我们需要指定该参数为: --tgt=std
* --train_prefix
该参数是train数据文件的前缀,注意 需要包含完整路径 ,路径可以是相对路径,也可以是绝对路径。举个例子,上述例子的两个文件我们放在 /tmp/nmt_model 目录下面,那么该参数需要设置为:--train_prefix=/tmp/nmt_model/address_train,那么train数据的完整路径就是: /tmp/nmt_model/address_train.ocr/tmp/nmt_model/address_train.std
* --dev_prefix
该参数指定dev数据文件的前缀,同--train_prefix类似。举个例子,在 /tmp/nmt_model 目录下面存放我们的dev数据文件 address_dev.ocr 和 address_dev.std,那么该参数应该指定为: --dev_prefix=/tmp/nmt_model/address_dev
* --test_prefix
该参数是test数据文件的前缀,其他和上述--train_prefix--dev_prefix类似。  
* --vocab_prefix
该参数指定的是词典文件的前缀,注意 需要包含完整路径 ,可以是相对路径也可以是绝对路径。举个例子,我们的词典文件为 vocab.ocr 和 vocab.std ,位于 /tmp/nmt_model/ 那么该参数应该指定为:--vocab_prefix=/tmp/nmt_model/vocab,最终的词典路径为 /tmp/nmt_model/vocab.ocr 和 /tmp/nmt_model/vocab.std 。  
* --embed_prefix
该参数指定已经训练好的embedding文件,必须是Glove文件格式。如果没有,使用默认值None。  
* --out_dir
该参数指定模型的保存路径。比如你想保存在 /tmp/ 目录下,那你这样指定:--out_dir=/tmp 。  

超参数的使用  

注意事项  

训练该模型,对机器的要求比较高。本人尝试过使用公司配的开发机器配置如下:  
* Platform Windows 7 x64
* Memory 8G
* CPU intel core i5-6500
* GPU GTX950 2G x1

此配置在我改小了batch_size到32之后,还是报错 Out of memoey.
在服务器配置如下:  
* Platform Ubuntu16.04 amd64
* Memory 32G
* CPU intel core i7-7700k
* GPU GTX1080ti 11G x1

上训练普通模型,并且将batch_size设置成默认的128,可以正常训练,但是此配置训练GNMT模型报错OOM。   

这里也就说明一个小技巧: 改小batch_size可以降低显存使用。

当然其他维度的降低也可以降低显存使用  

扩展--分布式训练  

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值