Windows下运行谷歌的im2txt(NIC)模型

前言

一些caption论文在谷歌NIC(下称im2txt)模型改进,原文为《Show and Tell: A Neural Image Caption Generator》以及《Show and Tell: Lessons learned from the 2015 MSCOCO Image Captioning Challenge》想跑这个模型,至少你已经对的文章有所了解,这里就不重复了。

本文仅仅是加载存档跑个demo,整合文件仅仅是整理了存档和inference部分。训练部分还是建议前往官方github按照指导在linux环境下运行(见section1.准备原料)。另外,想建立自己的caption模型不建议在该模型上修改,它太老了,并且有的技术实际上已经不实用了,新的开源模型也有不少。

关于数据集,其原文应该已经引用MSCOCO了。在linux系统下,在墙外且网速自信的,可以试试“research\im2txt\data\download_and_preprocess_mscoco.sh”脚本。

我直接从官网下的压缩包,http://cocodataset.org/#download

正文

已经有人跑过并写了博客,参见:

[1] https://blog.csdn.net/gbbb1234/article/details/70543584

[2] https://blog.csdn.net/sparkexpert/article/details/70846094

但是这是在linux上跑的,一些指令对于使用windows+pycharm的我非常不友好,需要改一些东西,换一些方法。试了好几天坑终于通了,前来填坑。

环境:win10 + Pycharm2017.2.3 + python3.6 + tensorflow1.7.0

先看这里:

这里准备了一份所有东西都调好的整合文件。不想按照步骤走一遍的,可以直接下载:

https://pan.baidu.com/s/1Q0kPEbVl1jIrX9kvZ2_xcA

将压缩包解压到E盘基本可以直接使用(相关依赖已经安装完全的前提下,tensorflow、numpy、nltk等)。

关于下文中的文件路径,尽量不要出现中文,会报错unkown error,这样的错误提示很难调试。

1.准备原料

模型的地址:https://github.com/tensorflow/models/tree/master/research/im2txt

一些博客中的模型地址都发生了错误,原因是谷歌在models文件下建立了research文件,然后把原先的模型放进去,所以那些链接失效了。

原模型的README文档很好的描述了在linux下跑模型,你该做的事情。我没有跑训练过程,只是想加载别人的存档试一下,这样至少有东西可以交差。

关于模型本身的获取:

对于github上的文件,直接下载得到的文件可能包含有html格式。而github网页自己只提供了这个大模型的下载链接,显然你并不想下载谷歌所有的开源模型。

使用以下网址,输入所需文件的链接后点击Download可以直接获得模型中的部分文件:

[3] http://kinolien.github.io/gitzip/

其他原料的获取:

关于官方指定的一些原料,已经获取的可以跳过。windows下的tensorflow、numpy、nltk可以参见:

https://blog.csdn.net/heros_never_die/article/details/79760616

而bazel则是最折腾人的了,但bazel不是必要的。它是一个编译器一样的东西,可以用你的工程生成可执行文件,windows下的话就是生成exe,如果你只是想在pycharm上跑一下的话,可以不下载bazel并跳过关于bazel的各种说明,官方文档提供的linux下运行方法是bazel

bazel下载地址:https://github.com/bazelbuild/bazel/releases

选择bazel-0.11.1-windows-x86_64.exe(162MB)即可。这不是一个安装包,你直接双击运行它也会提示你用cmd或者power shell打开它,它这么大是因为集成了jdk等必要的组件,详细的使用在下文介绍。

已经训练好的存档:

感谢博客[1]的博主给出了模型存档(链接在上文[1]处),下面的链接来自与博客[1]:

https://github.com/withyou1771/im2txt

其中model.ckpt-1000000.index   model.meta-1000000.meta  word_counts.txt是我们需要的文件。

但是一个完整的存档应该包含checkpoint、index、meta、data-00000-of-00001四个文件。data-00000-of-00001在其README文档的download里面,链接是谷歌云的,142M,你可能需要VPN来下载它,这里搬运一下地址:

https://drive.google.com/open?id=0B7k91FBdFbY7eVd1SHprQjdkWms

若你没有VPN,我这里转存了一下,不知道哪天会失效:https://pan.baidu.com/s/1PcEx8ZK66gF8vziHYMKIPw

checkpoint文件是一个名叫checkpoint,没有任何后缀名的文件,你可以创建一个txt文件再把名字改过来。里面应该按照这样的格式填写(对应的路径请自行修改):

model_checkpoint_path: "E:\\research\\im2txt\\data\\model.ckpt-1000000"

all_model_checkpoint_paths: "E:\\research\\im2txt\\data\\model.ckpt-1000000"

最后这四个文件在一起才是完整的存档,请放到同一个文件夹下:

2.相关文件的处理

上面的存档是还需要进一步处理的,直接使用的话编译时会出现报错

“Key LSTM/basic_lstm_cell/bias not found in checkpoin”

这是因为谷歌的代码是在tensorflow1.0上写的,而tensorflow1.2以后,对lstm单元的一些参数名做了更改,导致1.0或之前的存档在1.2及以后不能直接用。解决的办法来自以下博客:

https://blog.csdn.net/lgh0824/article/details/77417848

里面给出了一个存档版本转化的文件,亲测有用,下载地址:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py

这是个github上的文件,下载方法和模型本身一样,用上文[3]的链接

其文档里介绍了指令用法:python checkpoint_convert.py [--write_v1_checkpoint] \

      '/path/to/checkpoint' '/path/to/new_checkpoint'

意思是说,打开cmd,进入到这个py文件所在的目录环境,运行这条指令(确保你已经安装了python并且将它加入了系统环境变量)。其中[--write_v1_checkpoint]是可选参数,填上会输出v1的存档,一般是不需要的。后面两个字符串,第一个是待转化的ckpt的路径,第二个是转换完以后的路径。下图帮助理解:

这样在data2文件夹内就有转换好的文件了。两个路径一样的话会报错,因此建个新文件夹,然后把东西盖回去,或者在运行程序时指定存档路径的时候输入新地址。

我的checkpoint_convert文件放在E:python_poject文件夹里,所以使用E:和cd python_project两个命令就进入了文件夹内。cd是打开文件夹的意思,cd..可以退回上一级目录,跨磁盘需要先输入"磁盘名:",例如E:

3.运行模型

先让我们看一下工程的结构:

里面的WORKSPACE文件和BUILD文件是用来给bazel编译时指明工作空间和依赖关系的。观察BUILD文件可以得知,一些包属于libarary,而”train,evaluate,run_inference“这三个属于binary,是可以运行的。train用来训练模型,run_inference可以运行现有的模型,evaluate是用来评估的。

在run_inference.py文件里面,有着和运行关系比较紧密的几句tf.flags.DEFINE-string函数。这个函数定义了主函数的参数,有点类似于main(argv)。第一个字串是变量名,第二个串是变量默认值,原代码里是空的"",我为了方便加入了自己的路径,第三个串是对变量的说明。

checkpoint_path:存档文件所在路径

vocab_file:词汇表所在路径

input_files:输入的图像,可以有多个jpg文件地址,以逗号隔开。

如果不想修改原代码,需要使用带参方式运行函数,pycharm下带参运行的方法如下:

菜单栏Run->EditConfiguration->Sripct parameters栏里输入参数。

例如(以下三句在一行,空格隔开,--也是一部分,博客排版美观分成三行而已):

--checkpoint_path="E:/python_project/research/im2txt/data" 

--vocab_file="E:/python_project/research/im2txt/data/word_counts.txt" 

--input_files="E:/pic/cat.jpg","E:/pic/pic1.jpg"

如果不输入参数,上面的变量都会加载默认值。上面展示了输入多张图片的方法。

至此,模型就可以运行了,结果会给出三个caption:

4.使用bazel编译

这部分内容不是必要的,bazel可以生成exe文件,以后运行起来就挺方便。

1)首先,research文件夹下面的WORKSPACE.txt文件应该把后缀名去掉,变成没有后缀名的格式,这样bazel才能识别,这坑了我好一会,这可能是window和linux环境区别造成的。

2)然后把下载好的bazel.exe丢到WORKSPACE所在的文件夹里,打开cmd进入这个文件夹,就可以执行相关的bazel指令了。直接输入bazel和bazel help是等效的,可以查看bazel相关的指令信息。

3)输入bazel build -c opt im2txt/run_inference就会在当前目录下输出一系列文件夹,exe文件在bazel-bin里。

在我使用bazel的过程中,出现了一些其他问题。

运行exe后窗口一闪而过,截图发现里面是报错信息(我已经在pycharm运行成功了)。分析报错发现,from im2txt.ops import image_imbedding这样的语句会出错。可能bazel编译的时候和pycharm有一些区别。

解决的办法:

在im2txt.ops的文件夹里建立一个__init__.py的文件,里面输入这样的语句

 

__all__ = ['image_embedding.py','image_embeding_test.py','image_processing.py','inputs']

[]里的内容是该文件夹所有的文件名,这个文件使这个文件夹成了个包,可以被其他模块导入。把相应的import语句改为from ops import image_imbedding。所有类似的import都要改(inference_utils文件夹也需要__inti__.py)。

本身就在im2txt文件夹下的import语句改为import xxx即可。

成功使用exe运行模型后,还是会闪退,因为程序成功运行到底结束,就关闭了窗口。因此在原main函数下加了句os.system("pause"),这样程序窗口会留住,提示按任意键继续。

直接双击exe是以默认路径运行的,我们可以打开cmd,进入run_inference.exe所在目录,输入类似下面的指令:

start run_inference --input_files="E:/pic/pic1.jpg"

意思是带参数运行exe,checkpoint_path和vocab_file没填就是按默认路径,input_files同理可以多文件以逗号隔开。

 

5.训练方法简述

应评论要求追加训练过程的简易教程,内容较多因此简写,具体内容请有兴趣的人士自行研究。《深度学习原理与tensorflow实践》[1]一书中对该模型代码有细致的解读。

1.工程结构关键文件介绍:

inception_v3.ckpt  inception_v3存档文件,训练所必要的组件。

download_and_preprocess_mscoco.sh  

下载及解压mscoco2014并运行build_mscoco_data.py的脚本,需要linux下运行,非必要。

windows下使用相关软件执行可以成功下载图像文件,解压时会报错,脚本停止,可手动解压,标注文件需额外下载。也可以自行下载与解压或使用其他数据集。笔者没有使用COCO,而是使用了VIST数据集的DII部分进行训练。

checkpoint以及model.ckpt.....为存档文件。

word_counts.txt 为单词存档。

inference_utils 为模型预测时使用的相关组件。

configuration.py 参数设置文件,超参数基本都在这改。

evaluate.py 评估文件,计算困惑度,可结合验证集使用。

run_inference.py 预测文件,运行即可生成标注。

show_and_tell_model.py 模型定义文件,核心部分。

train.py 训练文件,运行可执行训练过程。

2.build_mscoco_data.py文件简介

其主要功能为,将训练集和验证集合并后按比例重新划分为训练集、验证集、测试集,并产生词表文件(即word_counts.txt)。

其命令行参数如下,按照(变量名,变量值,备注)格式定义,可按照自己的需要修改:

tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
                       "Training image directory.")
tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
                       "Validation image directory.")

tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
                       "Training captions JSON file.")
tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_val2014.json",
                       "Validation captions JSON file.")

tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")

tf.flags.DEFINE_integer("train_shards", 256,
                        "Number of shards in training TFRecord files.")
tf.flags.DEFINE_integer("val_shards", 4,
                        "Number of shards in validation TFRecord files.")
tf.flags.DEFINE_integer("test_shards", 8,
                        "Number of shards in testing TFRecord files.")

tf.flags.DEFINE_string("start_word", "<S>",
                       "Special word added to the beginning of each sentence.")
tf.flags.DEFINE_string("end_word", "</S>",
                       "Special word added to the end of each sentence.")
tf.flags.DEFINE_string("unknown_word", "<UNK>",
                       "Special word meaning 'unknown'.")
tf.flags.DEFINE_integer("min_word_count", 4,
                        "The minimum number of occurrences of each word in the "
                        "training set for inclusion in the vocabulary.")
tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
                       "Output vocabulary file of word counts.")

tf.flags.DEFINE_integer("num_threads", 8,
                        "Number of threads to preprocess the images.")
ImageMetadata = namedtuple("ImageMetadata",
                           ["image_id", "filename", "captions"])

image_dir 图像路径,支持jpg格式。

captions_file 标注文件,json格式。

output_dir 输出路径,程序将图片和对应的标注打包输出为TFRecord文件,划分为训练集256块,验证集4块,测试集8块。

语言模块的开始标记和结束标记为<S>与</S> ,词表外单词以<UNK>表示,词表单词最低频率4,程序线程数8.

以上均体现在相关参数中,根据英文容易理解。程序额外定义了三元组ImageMetadata。

3.训练文件train.py简介

 

tf.flags.DEFINE_string("input_file_pattern", "/VIST_dataset/SIS_TFRecord/train-?????-of-00256",
                       "File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("inception_checkpoint_file", "data/inception_v3.ckpt",
                       "Path to a pretrained inception_v3 model.")
tf.flags.DEFINE_string("train_dir", "data",
                       "Directory for saving and loading model checkpoints.")
tf.flags.DEFINE_boolean("train_inception", False,
                        "Whether to train inception submodel variables.")
tf.flags.DEFINE_integer("number_of_steps", 1000000, "Number of training steps.")
tf.flags.DEFINE_integer("log_every_n_steps", 1,
                        "Frequency at which loss and global step are logged.")

input_file_pattern,训练文件名称模式,build_mscoco_data所生成的文件均命名为类似于train-00001-of-00256,以?????占5个字符表示了名称符合这种模式的所有256个文件,文件所在路径请根据自身情况修改。

inception_checkpoint_file,inception存档文件所在位置

train_dir,训练文件目录,会在此文件夹中保存存档以及tensorboard日志。

train_inception,打开后将允许梯度流到inception部分,在关闭状态下训练100万步后语言模型成型后,打开此设定再进行联合训练200万步,将对模型有小幅度的优化,不做也没关系,具体说明见论文原文。

 

综上,训练过程可概述为:

1.运行build_mscoco_build.py产生词表以及TFRecord文件。(确保你已经有数据集并且修改了文件指向)

2.运行train.py文件读取TFRecord文件进行训练并产生模型存档与tensorboard日志。

   程序每10分存档一次,在窗口显示损失函数指。

3.运行inference.py文件进行摘要生成。

特殊说明:

存档文件换位置时,不仅仅需要在程序中修改读取路径,还需要修改存档文件指向。

使用记事本打开data目录下名为"chekcpoint"的无后缀名文件可以看到其对模型存档的进一步指示

model_checkpoint_path: "E:\\research\\im2txt\\data\\model.ckpt-1000000"

all_model_checkpoint_paths: "E:\\research\\im2txt\\data\\model.ckpt-1000000"

此处必须一同修改,由于存档有四个部分,程序实际查找到的是该文件,该文件进一步指示所包含的存档。

因此导致的程序无法运行,笔者深感抱歉,修改后的程序压缩包已上传,并加入了inception_v3存档,删去了word_cooutx.txt后面多余的一个小空格。

https://pan.baidu.com/s/1Q0kPEbVl1jIrX9kvZ2_xcA

 

参考文献:

[1]喻俨,莫瑜,王琛,等. 深度学习原理与tensorflow实践[M].北京:电子工业出版社, 2017.

评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值