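Learning Visual Question Answering (Part 1): The Dynamic Memory Network DMN+ for VQA (TensorFlow Implementation)

I. Background

The model implemented in this post comes from the paper "Dynamic Memory Networks for Visual and Textual Question Answering".

I spent the past two months on miscellaneous tasks, and now it is time to get hands-on with experiments, so I am starting by learning from someone else's implementation. This is the first post in a series on visual question answering experiments.

The experiment needs a lot of data: the images come from COCO, the text annotations come from VQA 1.0, and a pretrained VGG16 is also used, so there is quite a lot to prepare.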

II. Paper Overview

The paper can be downloaded from: https://arxiv.org/pdf/1603.01417.pdf

First, the abstract of the paper:

Neural network architectures with memory and attention mechanisms exhibit certain reasoning capabilities required for question answering. One such architecture, the dynamic memory network (DMN), obtained high accuracy on a variety of language tasks. However, it was not shown whether the architecture achieves strong results for question answering when supporting facts are not marked during training or whether it could be applied to other modalities such as images. Based on an analysis of the DMN, we propose several improvements to its memory and input modules. Together with these changes we introduce a novel input module for images in order to be able to answer visual questions. Our new DMN+ model improves the state of the art on both the Visual Question Answering dataset and the bAbI-10k text question-answering dataset without supporting fact supervision.

Memory mechanisms can handle a degree of reasoning, and attention mechanisms have been used successfully in machine translation and image captioning models. The Dynamic Memory Network (DMN) is a neural network model that uses both mechanisms. The paper analyzes the components of the DMN, mainly the input module and the memory module, with the goal of improving question answering. The authors' main contributions are:

We propose a new input module which uses a two level encoder with a sentence reader and input fusion layer to allow for information flow between sentences.

For the memory, we propose a modification to gated recurrent units (GRU). The new GRU formulation incorporates attention gates that are computed using global knowledge over the facts.

In addition, we introduce a new input module to represent images.

The biggest advantage of the new DMN+ model is that it does not require any labeled supporting facts during training (i.e., annotations of which facts are relevant for answering a particular question).

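The figure below illustrates the improved memory module applied to text QA and visual QA:

1. Dynamic Memory Networks (DMN)

The DMN is a question-answering model made up of several modules, such as an input representation module and a memory module. Some of these modules are shown in Figure 1 above: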

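Input Module: this module processes the input data about which the question is asked and converts it into a set of vectors called facts, denoted F = [f_{1}, ..., f_{N}], where N is the total number of facts; these vectors are used by the subsequent modules. In text QA the input module is a GRU. If x_{i} denotes the input at step i and h_{i-1} the previous hidden state, the current hidden state h_{i} = GRU(x_{i}, h_{i-1}) is computed as:

u_{i} = \sigma(W^{(u)} x_{i} + U^{(u)} h_{i-1} + b^{(u)})    (1)
r_{i} = \sigma(W^{(r)} x_{i} + U^{(r)} h_{i-1} + b^{(r)})    (2)
\tilde{h}_{i} = \tanh(W x_{i} + r_{i} \circ U h_{i-1} + b^{(h)})    (3)
h_{i} = u_{i} \circ \tilde{h}_{i} + (1 - u_{i}) \circ h_{i-1}    (4)

where \sigma is the sigmoid function and \circ denotes element-wise multiplication.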

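Question Module: this module computes a vector representation q of the question, where q is the final hidden state of a GRU run over the words of the question.

Episodic Memory Module: the episodic memory retrieves, from the input facts, the information needed to answer the question q. To improve the understanding of both the question and the input, especially when the question requires transitive reasoning, the module makes multiple passes over the input and updates the episodic memory after each pass. It consists of two parts: an attention mechanism and a memory update mechanism. The attention mechanism produces a contextual vector c^{t}, a summary of the input relevant to pass t, where relevance is inferred from the question q and the previous episodic memory m^{t-1}. The memory update mechanism then generates the new episodic memory m^{t} from the contextual vector c^{t} and the previous episodic memory m^{t-1}.

Answer Module: the answer module generates the model's predicted answer from the question q and the final episodic memory m^{T}. For a simple answer, such as a single word, a linear layer with softmax activation is sufficient. For a complex answer, such as a sentence, an RNN can decode the concatenation a = [q; m^{T}] into an ordered sequence of words. The model can be trained with a cross-entropy loss.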
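For the single-word case, the answer module can be sketched in plain Python as a linear layer followed by softmax, trained with cross-entropy. This is a minimal illustration, not the repository's code: feeding the concatenation [q; m^{T}] into the layer, the list-of-rows weight layout, and all numbers are assumptions chosen to keep the example small.

```python
import math

def linear(W, b, x):
    """y = W x + b, with W given as a list of rows (one row per output)."""
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]

def softmax(logits):
    top = max(logits)                       # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_answer(q, m_final, W, b):
    """Probability of each candidate answer: softmax(W [q; m^T] + b)."""
    a = q + m_final                         # list concatenation plays the role of [q; m^T]
    return softmax(linear(W, b, a))

def cross_entropy(probs, target_index):
    """Loss for one example whose correct answer has index target_index."""
    return -math.log(probs[target_index])

# Toy example: 1-d question and memory vectors, 3 candidate answers (made-up weights).
probs = predict_answer([1.0], [2.0], W=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], b=[0.0, 0.0, 0.0])
predicted = max(range(len(probs)), key=probs.__getitem__)   # logits are [1, 2, 3], so index 2
loss = cross_entropy(probs, target_index=predicted)
```

Because the softmax runs over a fixed set of candidate answers, single-word VQA reduces to classification.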

2. Improved Dynamic Memory Networks (DMN+)

The proposed model changes two components: the input representation, and the attention mechanism together with the memory update. Text QA and VQA differ only in the input representation.

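(1) Input Module for Text QA

In the DMN, a single GRU processes the words of the text to extract sentence representations. This works well on the bAbI-1k dataset, which provides supporting facts, but performs poorly on bAbI-10k without supporting facts. The authors suggest two reasons. First, the GRU only gives a sentence context from the sentences before it, not from those after it, which prevents information from propagating back from later sentences. Second, supporting sentences may be so far apart at the word level that they cannot interact through the word-level GRU.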

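Input Fusion Layer: DMN+ replaces the single GRU with two components. The first is a sentence reader, which encodes the words of a sentence into a sentence embedding. The second is the input fusion layer, which allows interactions between sentences, similar to a hierarchical neural auto-encoder architecture. The input fusion layer is a bi-directional GRU, so information can flow from both earlier and later sentences. Because gradients no longer need to propagate through the words between sentences, the fusion layer also lets distant supporting sentences interact more directly.

Fig 2 below shows the input module: the sentence reader uses a positional encoder and the input fusion layer uses a bi-directional GRU. Each sentence encoding f_{i} is the output of encoding the word features [w_{1}^{i}, ..., w_{M_{i}}^{i}], where M_{i} is the length of the sentence:

The sentence reader could use any of several encoding schemes; the authors chose positional encoding. GRUs and LSTMs were also considered, but they require more computation and are prone to overfitting unless auxiliary tasks, such as reconstructing the original sentence, are used.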

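For positional encoding, the sentence representation is f_{i} = \sum_{j=1}^{M} l_{j} \circ w_{j}^{i}, where \circ denotes element-wise multiplication and l_{j} is a column vector with l_{jd} = (1 - j/M) - (d/D)(1 - 2j/M); here M is the sentence length, d is the embedding index, and D is the embedding dimension.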
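The positional weights l_{jd} are easy to compute directly. The sketch below, in plain Python with word vectors represented as lists of floats (an assumption for readability), builds the weights and sums the weighted word vectors; the last two lines show that, unlike a plain bag of words, swapping the word order changes the encoding.

```python
def position_weights(M, D):
    """l[j-1][d-1] = (1 - j/M) - (d/D) * (1 - 2j/M) for j = 1..M, d = 1..D."""
    return [[(1 - j / M) - (d / D) * (1 - 2 * j / M) for d in range(1, D + 1)]
            for j in range(1, M + 1)]

def encode_sentence(word_vectors):
    """f = sum_j l_j * w_j (element-wise product), for M word vectors of size D."""
    M, D = len(word_vectors), len(word_vectors[0])
    L = position_weights(M, D)
    return [sum(L[j][d] * word_vectors[j][d] for j in range(M)) for d in range(D)]

f_ab = encode_sentence([[1.0, 0.0], [0.0, 1.0]])   # [0.5, 1.0]
f_ba = encode_sentence([[0.0, 1.0], [1.0, 0.0]])   # [0.5, 0.5]: same words, different order
```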

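The input fusion layer is a bi-directional GRU that takes the facts as input and lets them exchange information:

\overrightarrow{f_{i}} = GRU_{fwd}(f_{i}, \overrightarrow{f_{i-1}})
\overleftarrow{f_{i}} = GRU_{bwd}(f_{i}, \overleftarrow{f_{i+1}})
\overleftrightarrow{f_{i}} = \overleftarrow{f_{i}} + \overrightarrow{f_{i}}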
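The fusion step can be sketched as a forward and a backward recurrence whose states are summed position by position. In this plain-Python sketch the two GRU cells are replaced by a hypothetical running-sum step so the result is easy to check; any function with the signature step(fact, state) -> state, such as a GRU step, could be plugged in instead.

```python
def bidirectional_fusion(facts, step_fwd, step_bwd, h0):
    """Run a forward and a backward recurrence over the facts and sum the two states."""
    forward, h = [], h0
    for f in facts:                          # left to right: f_1 ... f_N
        h = step_fwd(f, h)
        forward.append(h)
    backward, h = [None] * len(facts), h0
    for i in reversed(range(len(facts))):    # right to left: f_N ... f_1
        h = step_bwd(facts[i], h)
        backward[i] = h
    return [[a + b for a, b in zip(fw, bw)] for fw, bw in zip(forward, backward)]

def running_sum(f, h):
    """Hypothetical stand-in for a GRU step: the state is the sum of all facts seen so far."""
    return [a + b for a, b in zip(f, h)]

fused = bidirectional_fusion([[1.0], [2.0], [3.0]], running_sum, running_sum, [0.0])
```

With the running-sum stand-in, every fused fact equals the sum of all facts plus the fact itself ([[7.0], [8.0], [9.0]] here), which shows that each position receives information from both directions.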

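(2) Input Module for VQA

To apply the DMN to visual question answering, the authors propose a new input module for images. It splits the input image into small local regions and treats each region as the equivalent of a sentence in the text input module. As shown in the figure below, the module has three parts: local region feature extraction, visual feature embedding, and the input fusion layer:

Local region feature extraction: image features are extracted with VGG-19. The input image is first resized to 448*448, and the output of the last pooling layer is taken, which has dimension d = 512*14*14. The pooling layer divides the image into a 14*14 grid, so the image corresponds to 196 local regions, each described by a 512-dimensional vector.

Visual feature embedding: since VQA involves both images and text, the authors add a linear layer with tanh activation that projects each local region vector into the textual feature space used by the question vector q.

Input fusion layer: the local region features extracted above carry no global information, which limits their representational power. To address this, the authors add an input fusion layer. First, the input facts F are produced by traversing the image in a snake-like fashion; then a bi-directional GRU is applied over these facts F to produce globally aware input facts. The bi-directional GRU lets information propagate between neighboring regions of the image.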
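Two details of the VQA input module can be checked with a few lines of plain Python: the 14*14 grid follows from VGG's five 2*2, stride-2 max-pooling layers (448 / 2^5 = 14), and the snake-like traversal turns that grid into a sequence in which consecutive facts are spatial neighbors. The sketch assumes the traversal starts at the top-left corner and alternates direction row by row; the paper's figure may differ in these details, and the function names are illustrative.

```python
def pooled_grid_side(input_side, num_pools=5):
    """Spatial side length after VGG's 5 max-pool layers (2x2 kernel, stride 2)."""
    side = input_side
    for _ in range(num_pools):
        side //= 2
    return side

def snake_order(rows, cols):
    """Visit grid cells row by row, alternating left-to-right and right-to-left."""
    order = []
    for r in range(rows):
        cols_in_row = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)
        order.extend((r, c) for c in cols_in_row)
    return order

def grid_to_facts(feature_map):
    """feature_map[r][c] holds the feature vector of region (r, c); returns the facts in snake order."""
    return [feature_map[r][c] for r, c in snake_order(len(feature_map), len(feature_map[0]))]

side = pooled_grid_side(448)    # 448 -> 224 -> 112 -> 56 -> 28 -> 14
num_regions = side * side       # 196 local regions
```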

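(3) The Episodic Memory Module

The episodic memory module is shown in the figure below:

The module retrieves information from the input facts by focusing attention on a subset of them. This is implemented by associating a single scalar attention gate g_{i}^{t} with each fact \overleftrightarrow{f_{i}} during pass t. The gate is computed from the interactions between the fact, the question representation q, and the previous episodic memory state m^{t-1}:

z_{i}^{t} = [\overleftrightarrow{f_{i}} \circ q; \overleftrightarrow{f_{i}} \circ m^{t-1}; |\overleftrightarrow{f_{i}} - q|; |\overleftrightarrow{f_{i}} - m^{t-1}|]
Z_{i}^{t} = W^{(2)} \tanh(W^{(1)} z_{i}^{t} + b^{(1)}) + b^{(2)}
g_{i}^{t} = \frac{\exp(Z_{i}^{t})}{\sum_{k=1}^{N} \exp(Z_{k}^{t})}

where [;] denotes concatenation and |\cdot| is the element-wise absolute value.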
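The gate computation maps directly to code. Below is a plain-Python sketch in which a tiny hand-picked two-layer scorer stands in for the learned W^{(1)}, b^{(1)}, W^{(2)}, b^{(2)} (the weights are hypothetical). With these weights the fact that matches both the question and the memory receives the larger gate, and the gates always sum to 1 because of the softmax.

```python
import math

def interaction(f, q, m):
    """z = [f*q; f*m; |f-q|; |f-m|], built from element-wise products and absolute differences."""
    return ([a * b for a, b in zip(f, q)] + [a * b for a, b in zip(f, m)] +
            [abs(a - b) for a, b in zip(f, q)] + [abs(a - b) for a, b in zip(f, m)])

def make_scorer(W1, b1, W2, b2):
    """Return score(z) = W2 . tanh(W1 z + b1) + b2, a two-layer network with a scalar output."""
    def score(z):
        hidden = [math.tanh(sum(w * v for w, v in zip(row, z)) + bias)
                  for row, bias in zip(W1, b1)]
        return sum(w * h for w, h in zip(W2, hidden)) + b2
    return score

def attention_gates(facts, q, m_prev, score):
    """g_i = exp(Z_i) / sum_k exp(Z_k), a softmax over all facts of the current pass."""
    Z = [score(interaction(f, q, m_prev)) for f in facts]
    top = max(Z)                             # shift for numerical stability
    exps = [math.exp(v - top) for v in Z]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical weights: one hidden unit that sums the two element-wise product blocks.
score = make_scorer(W1=[[1, 1, 1, 1, 0, 0, 0, 0]], b1=[0.0], W2=[1.0], b2=0.0)
gates = attention_gates([[1.0, 0.0], [0.0, 1.0]], q=[1.0, 0.0], m_prev=[1.0, 0.0], score=score)
```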

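Attention mechanism:

Once the attention gates g_{i}^{t} are available, an attention mechanism extracts the contextual vector c^{t}. Two types of attention are considered: soft attention and an attention based GRU.

Soft attention: soft attention produces the contextual vector c^{t} as the sum of the sorted list of vectors \overleftrightarrow{F} weighted by their attention gates g_{i}^{t}, i.e., c^{t} = \sum_{i=1}^{N} g_{i}^{t} \overleftrightarrow{f_{i}}. This has two advantages: first, it is easy to compute; second, if the softmax activation is spiky, it approximates a hard attention function by selecting a single fact for the contextual vector, while remaining differentiable.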
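Soft attention is just a gate-weighted sum of the facts. This plain-Python sketch, with illustrative toy facts, also shows the hard-attention limit mentioned above: when one gate is 1 and the rest are 0, the context vector is exactly that fact.

```python
def soft_attention(gates, facts):
    """c^t = sum_i g_i^t * f_i, computed dimension by dimension."""
    dim = len(facts[0])
    return [sum(g * f[d] for g, f in zip(gates, facts)) for d in range(dim)]

facts = [[4.0, 0.0], [0.0, 4.0]]
c_mixed = soft_attention([0.25, 0.75], facts)   # [1.0, 3.0]
c_hard = soft_attention([1.0, 0.0], facts)      # [4.0, 0.0]: a one-hot gate picks a single fact
```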

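Attention based GRU: for more complex queries, the attention should be sensitive to both the position and the ordering of the input facts, and an RNN is well suited to this. The authors propose a modified GRU that embeds information from the attention mechanism. The update gate u_{i} in Eq. (1) decides how much of each dimension of the hidden state to retain and how much to update with the input x_{i} of the current step. Because u_{i} is computed only from the current input and the previous hidden state, it has no knowledge of the question or of the previous episodic memory.

Therefore, the update gate u_{i} in Eq. (4) is replaced with the attention gate g_{i}^{t}, so that the GRU uses the attention to update its internal state. The process is illustrated in the figure below:

In equation form:

h_{i} = g_{i}^{t} \circ \tilde{h}_{i} + (1 - g_{i}^{t}) \circ h_{i-1}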
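The difference between the standard GRU and the attention based GRU comes down to the update gate. The plain-Python sketch below implements one GRU step (update gate, reset gate, candidate state) and lets an optional scalar attention gate replace the update gate; running it over all facts and keeping the final hidden state as c^{t} follows the paper's description of the attention based GRU. The all-zero weights and the dictionary weight layout are hypothetical and only make the outputs easy to verify.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def gru_step(x, h_prev, p, attention_gate=None):
    """One GRU step. If attention_gate (a scalar g_i^t) is given, it replaces
    the vector update gate u_i, which gives the attention based GRU."""
    Wx, Uh = matvec(p["W"], x), matvec(p["U"], h_prev)
    r = [sigmoid(a + b + c) for a, b, c in zip(matvec(p["Wr"], x), matvec(p["Ur"], h_prev), p["br"])]
    h_tilde = [math.tanh(a + rk * b + c) for a, rk, b, c in zip(Wx, r, Uh, p["bh"])]
    if attention_gate is None:
        u = [sigmoid(a + b + c) for a, b, c in zip(matvec(p["Wu"], x), matvec(p["Uu"], h_prev), p["bu"])]
    else:
        u = [attention_gate] * len(h_prev)   # the same scalar gates every dimension
    return [uk * ht + (1.0 - uk) * hp for uk, ht, hp in zip(u, h_tilde, h_prev)]

def attention_gru_context(facts, gates, h0, p):
    """Run the attention based GRU over the facts; the final hidden state serves as c^t."""
    h = h0
    for f, g in zip(facts, gates):
        h = gru_step(f, h, p, attention_gate=g)
    return h

def zero_params(n_in, n_hid):
    """Hypothetical all-zero weights, used only to make the example deterministic."""
    def zeros(rows, cols):
        return [[0.0] * cols for _ in range(rows)]
    p = {name: zeros(n_hid, n_in) for name in ("Wu", "Wr", "W")}
    p.update({name: zeros(n_hid, n_hid) for name in ("Uu", "Ur", "U")})
    p.update({name: [0.0] * n_hid for name in ("bu", "br", "bh")})
    return p
```

With zero weights the candidate state is tanh(0) = 0 and the update gate is sigmoid(0) = 0.5, so a standard step halves h_prev, a gate of 1 overwrites the state with the candidate, and a gate of 0 leaves the state untouched, which is how facts with near-zero attention are effectively skipped.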

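An important point is that g_{i}^{t} is a scalar produced by a softmax activation, whereas u_{i} is a vector produced by a sigmoid activation. This makes it easier to visualize how the attention gates activate over the input, as shown in the figure below:

Episode Memory Updates

After each pass of the attention mechanism, the episodic memory m^{t-1} is updated with the contextual vector c^{t} to produce m^{t}. In the DMN, this is done by a GRU whose initial hidden state is set to the question vector q. For pass t, the episodic memory is computed as:

m^{t} = GRU(c^{t}, m^{t-1})

Prior work has suggested that using different weights for each pass through the episodic memory may be advantageous. When the model contains only one set of weights shared by all episodic passes over the input, it is called a tied model, as shown in the "Mem Weights" row of Table 1:

For the episodic memory update, the authors use a ReLU layer; the new episodic memory state is computed as:

m^{t} = ReLU(W^{t}[m^{t-1}; c^{t}; q] + b)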
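The ReLU memory update and the untied weights can be sketched as follows. In this plain-Python version compute_context is a hypothetical callable standing in for the attention step over the facts, and initializing m^{0} with q is an assumption carried over from the DMN. Passing one (W^{t}, b^{t}) pair per pass gives the untied model, while reusing the same pair for every pass corresponds to the tied model.

```python
def relu_memory_update(m_prev, c, q, W_t, b_t):
    """m^t = ReLU(W^t [m^{t-1}; c^t; q] + b)."""
    x = m_prev + c + q                       # list concatenation of the three vectors
    return [max(0.0, sum(w * v for w, v in zip(row, x)) + bias)
            for row, bias in zip(W_t, b_t)]

def run_memory_passes(q, compute_context, pass_weights):
    """One (W^t, b^t) pair per pass; m^0 is initialized to q (assumption)."""
    m = q
    for W_t, b_t in pass_weights:
        c = compute_context(m)               # attention over the facts, conditioned on m^{t-1}
        m = relu_memory_update(m, c, q, W_t, b_t)
    return m
```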

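The final output of the memory module is passed to the answer module, which is the same as in the DMN.

3. Datasets

The experiments use three datasets:

bAbI-10k: a synthetic dataset with 20 tasks; each example consists of a set of facts, a question, an answer, and the supporting facts that lead to the answer.

DAQUAR-ALL visual dataset: The DAtaset for QUestion Answering on Real-world images (DAQUAR) contains 795 training images and 654 test images, with 6795 training questions and 5673 test questions.

VQA: the VQA 1.0 dataset.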

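4. Experiments

(1) Model analysis

The authors analyze the DMN and several of its variants. ODMN is the original DMN; DMN2 replaces the input module with the input fusion layer; DMN3 builds on DMN2 and replaces soft attention with the attention based GRU; DMN+ builds on DMN3 and updates the memory at each pass with a linear layer with ReLU activation and a separate set of weights per pass. All results are listed in Table 1.

(2) Comparison on bAbI-10k

On this dataset, the optimizer is Adam with a learning rate of 0.001, a batch size of 128, and 256 epochs, with early stopping if the validation loss does not improve within 20 epochs. Word embeddings are initialized from a random uniform distribution over [-\sqrt{3}, \sqrt{3}], and all other weights use Xavier initialization. The hidden size is d = 80, and l2 regularization is applied to all weights. Dropout keeps 90% of the inputs. The episodic memory module makes 3 passes. The tasks and their results are shown in the figure below: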

 

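(3) Comparison on VQA

The optimizer is Adam with a learning rate of 0.003 and a batch size of 100, trained for up to 256 epochs, with early stopping if the validation loss does not improve within 10 epochs. Weights are initialized from a random uniform distribution over [-0.08, 0.08]. The hidden size is d = 512 and the dropout rate is 0.5. The results are shown mainly in Fig 6.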
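Both training setups use early stopping on the validation loss (patience of 20 epochs on bAbI-10k, 10 on VQA). One way to express that rule, reading "does not improve within N epochs" as N epochs without a new best loss, is sketched below in plain Python; the exact bookkeeping in the authors' training code may differ.

```python
def early_stop_epoch(val_losses, patience):
    """Return the epoch at which training stops: the first epoch that is `patience`
    epochs past the best validation loss, or the last epoch if that never happens."""
    best_loss, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch
    return len(val_losses) - 1
```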

III. The Experiment

The code in this post is mainly based on: https://github.com/DeepRNN/visual_question_answering

The full file structure is:

|------ base_model.py
|------ config.py
|------ dataset.py
|------ episodic_memory.py
|------ main.py
|------ model.py
|------ vgg16_no_fc.npy                    # download separately
|------ utils
        |------ vqa
                |------ vqa.py
                |------ vqaEval.py
        |------ vocabulary.py
        |------ misc.py
        |------ nn.py
|------ train                              # download separately
        |------ images
                |------ image01.jpg
                |------ image02.jpg
                |------ ......
        |------ mscoco_train2014_annotations.json
        |------ OpenEnded_mscoco_train2014_questions.json
|------ val                                # download separately
        |------ images
                |------ image01.jpg
                |------ image02.jpg
                |------ ......
        |------ mscoco_val2014_annotations.json
        |------ OpenEnded_mscoco_val2014_questions.json
|------ test
        |------ images

1. Experimental environment

My environment:

python 3.7

GPU: GTX 1050TI 4G

tensorflow 1.14

numpy 1.16.2

opencv 3.4.1

Natural Language Toolkit (NLTK) 3.4

Pandas 0.24.2

Matplotlib 3.0.3

tqdm 4.31.1

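2. Dataset preparation

(1) Image data

Training images: http://images.cocodataset.org/zips/train2014.zip

Validation images: http://images.cocodataset.org/zips/val2014.zip

Test images: http://images.cocodataset.org/zips/test2014.zip

These links can be opened directly in a browser or downloaded with Xunlei (Thunder). After downloading, unzip the archives and put the images in the 'images/' folder under the corresponding directory.

(2) Text data

Open the official VQA download page: https://visualqa.org/vqa_v1_download.html

Then find and download the following 5 annotation files:

(3) VGG data

The download link for the pretrained VGG16 model given by the code's author no longer works, so I found the file elsewhere and uploaded it to my Baidu Netdisk: link: https://pan.baidu.com/s/1jPzXKZIXbNnknT7Nubh3yw  extraction code: nms0

I also found a pretrained ResNet-101, which could be used for image feature extraction as well; it can be downloaded with Xunlei or directly in a browser: http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz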

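3. Code

Once the data is arranged according to the file structure above, you can start writing the experiment code. The files used in the experiment are introduced one by one below:

(1) Hyperparameter file config.py

This file defines the hyperparameters of the experiment. Since there are many, they are wrapped in a class whose constructor takes no arguments; to change a hyperparameter, edit this file directly: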

class Config(object):
    """ Wrapper class for various (hyper)parameters. """
    def __init__(self):
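        # Model architecture
        self.cnn = 'vgg16'               # 'vgg16' or 'resnet50'
        self.max_question_length = 30    # maximum question length (in words)
        self.dim_embedding = 512         # word embedding size
        self.num_gru_units = 512         # GRU hidden size
        self.memory_step = 3             # number of episodic memory passes
        self.memory_update = 'relu'      # 'gru' or 'relu'
        self.attention = 'gru'           # 'gru' or 'soft'; set to 'soft' for ablation experiments
        self.tie_memory_weight = False   # False: separate memory-update weights for each pass
        self.question_encoding = 'gru'   # 'gru' or 'positional'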
        self.embed_fact = False

        # Weight initialization and regularization
        self.fc_kernel_initializer_scale = 0.08
        self.fc_kernel_regularizer_scale = 1e-6
        self.fc_activity_regularizer_scale = 0.0
        self.conv_kernel_regularizer_scale = 1e-6
        self.conv_activity_regularizer_scale = 0.0
        self.fc_drop_rate = 0.5
        self.gru_drop_rate = 0.3

        # Optimization
        self.num_epochs = 100
        self.batch_size = 4
        self.optimizer = 'Adam'          # 'Adam', 'RMSProp', 'Momentum' or 'SGD'
        self.initial_learning_rate = 0.0001
        self.learning_rate_decay_factor = 1.0
        self.num_steps_per_decay = 10000
        self.clip_gradients = 10.0
        self.momentum = 0.0
        self.use_nesterov = True
        self.decay = 0.9
        self.centered = True
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 1e-5

        # Saving checkpoints and summaries
        self.save_period = 1000
        self.save_dir = './models/'
        self.summary_dir = './summary/'

        # Vocabulary
        self.vocabulary_file = './vocabulary.csv'

        # Training data
        self.train_image_dir = './train/images/'
        self.train_question_file = './train/OpenEnded_mscoco_train2014_questions.json'
        self.train_answer_file = './train/mscoco_train2014_annotations.json'
        self.temp_train_annotation_file = './train/anns.csv'
        self.temp_train_data_file = './train/data.npy'

        # Evaluation
        self.eval_image_dir = './val/images/'
        self.eval_question_file = './val/OpenEnded_mscoco_val2014_questions.json'
        self.eval_answer_file = './val/mscoco_val2014_annotations.json'
        self.temp_eval_annotation_file = './val/anns.csv'
        self.temp_eval_data_file = './val/data.npy'
        self.eval_result_dir = './val/results/'
        self.eval_result_file = './val/results.json'
        self.save_eval_result_as_image = False

        # Testing
        self.test_image_dir = './test/images/'
        self.test_question_file = './test/questions.csv'
        self.temp_test_info_file = './test/info.csv'
        self.test_result_dir = './test/results/'
        self.test_result_file = './test/results.csv'

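(2) Files in the utils folder

The utils folder contains a vqa subfolder and several other files; let's go through them one by one:

misc.py: this file defines a single image-loading class with two methods, one that loads a single image and one that loads a list of images. Here is the code: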

import numpy as np
import cv2

class ImageLoader(object):
    def __init__(self, mean_file):
        self.bgr = True
        self.scale_shape = np.array([224, 224], np.int32)
        self.crop_shape = np.array([224, 224], np.int32)
        self.mean = np.load(mean_file).mean(1).mean(1)

    def load_image(self, image_file):
        """ Load and preprocess an image. """
        image = cv2.imread(image_file)

        if self.bgr:
            temp = image.swapaxes(0, 2)
            temp = temp[::-1]                    # reverse the channel axis: BGR -> RGB
            image = temp.swapaxes(0, 2)

        image = cv2.resize(image, (self.scale_shape[0], self.scale_shape[1]))
        offset = (self.scale_shape - self.crop_shape) / 2
        offset = offset.astype(np.int32)
        image = image[offset[0]:offset[0]+self.crop_shape[0],
                      offset[1]:offset[1]+self.crop_shape[1]]
        image = image - self.mean
        return image

    def load_images(self, image_files):
        """ Load and preprocess a list of images. """
        images = []
        for image_file in image_files:
            images.append(self.load_image(image_file))
        images = np.array(images, np.float32)
        return images
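Two steps in load_image are worth spelling out. The swapaxes / [::-1] / swapaxes sequence simply reverses the channel axis (for a NumPy array it is equivalent to image[:, :, ::-1]), turning OpenCV's BGR order into RGB. The center crop starts at offset (scale - crop) / 2, which is 0 here because scale_shape and crop_shape are both 224*224, so the crop is effectively a no-op; also note that cv2.resize expects its size argument as (width, height), which does not matter for a square shape. The plain-Python sketch below mirrors both steps on nested lists; the function names are illustrative.

```python
def center_crop_box(scale_hw, crop_hw):
    """(top, left, bottom, right) of a centered crop, as computed in load_image."""
    top = (scale_hw[0] - crop_hw[0]) // 2
    left = (scale_hw[1] - crop_hw[1]) // 2
    return top, left, top + crop_hw[0], left + crop_hw[1]

def bgr_to_rgb(image):
    """Reverse the channel order of every pixel; image[row][col] is a [B, G, R] list."""
    return [[pixel[::-1] for pixel in row] for row in image]
```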

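nn.py: defines the building blocks of the network, including convolution, max pooling, fully connected (dense), dropout, batch normalization (BN), and GRU layers: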

import tensorflow as tf
import tensorflow.contrib.layers as layers


class NN(object):
    def __init__(self, config):
        self.config = config
        self.is_train = True if config.phase == 'train' else False
        self.train_cnn = self.is_train and config.train_cnn
        self.prepare()

    def prepare(self):
        """ Setup the weight initalizers and regularizers. """
        config = self.config

        self.conv_kernel_initializer = layers.xavier_initializer()

        if self.train_cnn and config.conv_kernel_regularizer_scale > 0:
            self.conv_kernel_regularizer = layers.l2_regularizer(
                scale = config.conv_kernel_regularizer_scale)
        else:
            self.conv_kernel_regularizer = None

        if self.train_cnn and config.conv_activity_regularizer_scale > 0:
            self.conv_activity_regularizer = layers.l1_regularizer(
                scale = config.conv_activity_regularizer_scale)
        else:
            self.conv_activity_regularizer = None

        self.fc_kernel_initializer = tf.random_uniform_initializer(
            minval = -config.fc_kernel_initializer_scale,
            maxval = config.fc_kernel_initializer_scale)

        if self.is_train and config.fc_kernel_regularizer_scale > 0:
            self.fc_kernel_regularizer = layers.l2_regularizer(
                scale = config.fc_kernel_regularizer_scale)
        else:
            self.fc_kernel_regularizer = None

        if self.is_train and config.fc_activity_regularizer_scale > 0:
            self.fc_activity_regularizer = layers.l1_regularizer(
                scale = config.fc_activity_regularizer_scale)
        else:
            self.fc_activity_regularizer = None

    def conv2d(self,
               inputs,
               filters,
               kernel_size = (3, 3),
               strides = (1, 1),
               activation = tf.nn.relu,
               use_bias = True,
               name = None):
        """ 2D Convolution layer. """
        if activation is not None:
            activity_regularizer = self.conv_activity_regularizer
        else:
            activity_regularizer = None
        return tf.layers.conv2d(
            inputs = inputs,
            filters = filters,
            kernel_size = kernel_size,
            strides = strides,
            padding='same',
            activation = activation,
            use_bias = use_bias,
            trainable = self.train_cnn,
            kernel_initializer = self.conv_kernel_initializer,
            kernel_regularizer = self.conv_kernel_regularizer,
            activity_regularizer = activity_regularizer,
            name = name)

    def max_pool2d(self,
                   inputs,
                   pool_size = (2, 2),
                   strides = (2, 2),
                   name = None):
        """ 2D Pooling layer. """
        return tf.layers.max_pooling2d(
            inputs = inputs,
            pool_size = pool_size,
            strides = strides,
            padding='same',
            name = name)

    def dense(self,
              inputs,
              units,
              activation = tf.tanh,
              use_bias = True,
              name = None):
        """ Fully-connected layer. """
        if activation is not None:
            activity_regularizer = self.fc_activity_regularizer
        else:
            activity_regularizer = None
        return tf.layers.dense(
            inputs = inputs,
            units = units,
            activation = activation,
            use_bias = use_bias,
            trainable = self.is_train,
            kernel_initializer = self.fc_kernel_initializer,
            kernel_regularizer = self.fc_kernel_regularizer,
            activity_regularizer = activity_regularizer,
            name = name)

    def dropout(self,
                inputs,
                name = None):
        """ Dropout layer. """
        return tf.layers.dropout(
            inputs = inputs,
            rate = self.config.fc_drop_rate,
            training = self.is_train)

    def batch_norm(self,
                   inputs,
                   name = None):
        """ Batch normalization layer. """
        return tf.layers.batch_normalization(
            inputs = inputs,
            training = self.train_cnn,
            trainable = self.train_cnn,
            name = name
        )

    def gru(self):
        """ GRU layer. """
        gru = tf.nn.rnn_cell.GRUCell(
            num_units = self.config.num_gru_units,
            kernel_initializer = self.fc_kernel_initializer)
        if self.is_train:
            gru = tf.nn.rnn_cell.DropoutWrapper(
                gru,
                input_keep_prob = 1.0 - self.config.gru_drop_rate,
                output_keep_prob = 1.0 - self.config.gru_drop_rate,
                state_keep_prob = 1.0 - self.config.gru_drop_rate)
        return gru

vocabulary.py: builds the vocabulary and saves it to a CSV file:

import os
import numpy as np
import pandas as pd
from nltk.tokenize import word_tokenize

class Vocabulary(object):
    def __init__(self, save_file = None):
        self.words = []
        self.word2idx = {}
        self.size = 0
        self.word_counts = {}
        self.word_frequencies = []
        if save_file is not None:
            self.load(save_file)
        else:
            self.add_words(["<unknown>"])

    def add_words(self, words):
        """ Add new words to the vocabulary. """
        for w in words:
            if w not in self.word2idx.keys():
                self.words.append(w)
                self.word2idx[w] = self.size
                self.size += 1
            self.word_counts[w] = self.word_counts.get(w, 0) + 1

    def compute_frequency(self):
        """ Compute the frequency of each word. """
        self.word_frequencies = []
        for w in self.words:
            self.word_frequencies.append(self.word_counts[w])
        self.word_frequencies = np.array(self.word_frequencies, np.float32)
        self.word_frequencies /= np.sum(self.word_frequencies)
        self.word_frequencies = np.log(self.word_frequencies)
        self.word_frequencies -= np.max(self.word_frequencies)

    def word_to_idx(self, word):
        """ Translate a word into its index. """
        return self.word2idx[word] if word in self.word2idx.keys() else 0

    def process_sentence(self, sentence):
        """ Tokenize a sentence, and translate each token into its index
            in the vocabulary. """
        words = word_tokenize(sentence.lower())
        word_idxs = [self.word_to_idx(w) for w in words]
        return word_idxs

    def save(self, save_file):
        """ Save the vocabulary to a file. """
        data = pd.DataFrame({'word': self.words,
                             'index': list(range(self.size)),
                             'frequency': self.word_frequencies})
        data.to_csv(save_file)

    def load(self, save_file):
        """ Load the vocabulary from a file. """
        assert os.path.exists(save_file)
        data = pd.read_csv(save_file)
        self.words = data['word'].values
        self.size = len(self.words)
        self.word2idx = {self.words[i]:i for i in range(self.size)}
        self.word_frequencies = data['frequency'].values
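compute_frequency stores, for every word, the log of its relative frequency shifted so that the most frequent word gets 0 and all other words get negative values, and word_to_idx maps any out-of-vocabulary word to index 0, which is "<unknown>" because it is the first word added. A plain-Python restatement of both, handy for checking the saved CSV by hand (the function names are illustrative):

```python
import math

def log_relative_frequencies(counts):
    """Same arithmetic as Vocabulary.compute_frequency: log(count / total) minus its maximum."""
    total = float(sum(counts))
    logs = [math.log(c / total) for c in counts]
    top = max(logs)
    return [v - top for v in logs]

def word_to_idx(word2idx, word):
    """Out-of-vocabulary words fall back to index 0, i.e. "<unknown>"."""
    return word2idx.get(word, 0)
```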

Finally, the two files in the vqa folder: vqa.py and vqaEval.py.

Main contents of "vqa/vqa.py":

__author__ = 'aagrawal'
__version__ = '0.9'

# Interface for accessing the VQA dataset.

# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).

# The following functions are defined:
#  VQA        - VQA class that loads VQA annotation file and prepares data structures.
#  getQuesIds - Get question ids that satisfy given filter conditions.
#  getImgIds  - Get image ids that satisfy given filter conditions.
#  loadQA     - Load questions and answers with the specified question ids.
#  showQA     - Display the specified questions and answers.
#  loadRes    - Load result file and create result object.

# Help on each function can be accessed by: "help(COCO.function)"

import json
import datetime
import copy
from tqdm import tqdm
from nltk.tokenize import word_tokenize

class VQA:
        def __init__(self, annotation_file=None, question_file=None):
                """
                Constructor of VQA helper class for reading and visualizing questions and answers.
                :param annotation_file (str): location of VQA annotation file
                :return:
                """
                # load dataset
                self.dataset = {}
                self.questions = {}
                self.qa = {}
                self.qqa = {}
                self.imgToQA = {}
                self.max_ques_len = 0
                if not annotation_file == None and not question_file == None:
                        print('loading VQA annotations and questions into memory...')
                        time_t = datetime.datetime.utcnow()
                        dataset = json.load(open(annotation_file, 'r'))
                        questions = json.load(open(question_file, 'r'))
                        print(datetime.datetime.utcnow() - time_t)
                        self.dataset = dataset
                        self.questions = questions
                        self.process_dataset()
                        self.createIndex()

        def createIndex(self):
                # create index
                print('creating index...')
                imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
                qa =  {ann['question_id']: [] for ann in self.dataset['annotations']}
                qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
                max_ques_len = 0
                for ann in self.dataset['annotations']:
                        imgToQA[ann['image_id']] += [ann]
                        qa[ann['question_id']] = ann
                for ques in self.questions['questions']:
                        qqa[ques['question_id']] = ques
                        max_ques_len = max(max_ques_len, len(word_tokenize(ques['question'])))

                # create class members
                self.qa = qa
                self.qqa = qqa
                self.imgToQA = imgToQA
                self.max_ques_len = max_ques_len

        def info(self):
                """
                Print information about the VQA annotation file.
                :return:
                """
                for key, value in list(self.dataset['info'].items()):
                        print('%s: %s'%(key, value))

        def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
                """
                Get question ids that satisfy given filter conditions. default skips that filter
                :param         imgIds    (int array)   : get question ids for given imgs
                                quesTypes (str array)   : get question ids for given question types
                                ansTypes  (str array)   : get question ids for given answer types
                :return:    ids   (int array)   : integer array of question ids
                """
                imgIds    = imgIds    if type(imgIds)    == list else [imgIds]
                quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
                ansTypes  = ansTypes  if type(ansTypes)  == list else [ansTypes]

                if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
                        anns = self.dataset['annotations']
                else:
                        if not len(imgIds) == 0:
                                anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
                        else:
                                anns = self.dataset['annotations']
                        anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
                        anns = anns if len(ansTypes)  == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
                ids = [ann['question_id'] for ann in anns]
                return ids

        def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
                """
                Get image ids that satisfy given filter conditions. default skips that filter
                :param quesIds   (int array)   : get image ids for given question ids
                quesTypes (str array)   : get image ids for given question types
                ansTypes  (str array)   : get image ids for given answer types
                :return: ids     (int array)   : integer array of image ids
                """
                quesIds   = quesIds   if type(quesIds)   == list else [quesIds]
                quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
                ansTypes  = ansTypes  if type(ansTypes)  == list else [ansTypes]

                if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
                        anns = self.dataset['annotations']
                else:
                        if not len(quesIds) == 0:
                                anns = [self.qa[quesId] for quesId in quesIds if quesId in self.qa]  # self.qa maps each id to one annotation dict
                        else:
                                anns = self.dataset['annotations']
                        anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
                        anns = anns if len(ansTypes)  == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
                ids = [ann['image_id'] for ann in anns]
                return ids

        def loadQA(self, ids=[]):
                """
                Load questions and answers with the specified question ids.
                :param ids (int array)       : integer ids specifying question ids
                :return: qa (object array)   : loaded qa objects
                """
                if type(ids) == list:
                        return [self.qa[id] for id in ids]
                elif type(ids) == int:
                        return [self.qa[ids]]

        def showQA(self, anns):
                """
                Display the specified annotations.
                :param anns (array of object): annotations to display
                :return: None
                """
                if len(anns) == 0:
                        return 0
                for ann in anns:
                        quesId = ann['question_id']
                        print("Question: %s" %(self.qqa[quesId]['question']))
                        for ans in ann['answers']:
                                print("Answer %d: %s" %(ans['answer_id'], ans['answer']))

        def loadRes(self, resFile, quesFile):
                """
                Load result file and return a result object.
                :param   resFile (str)     : file name of result file
                :return: res (obj)         : result api object
                """
                res = VQA()
                res.questions = json.load(open(quesFile))
                res.dataset['info'] = copy.deepcopy(self.questions['info'])
                res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
                res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
                res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
                res.dataset['license'] = copy.deepcopy(self.questions['license'])

                print('Loading and preparing results...     ')
                time_t = datetime.datetime.utcnow()
                anns    = json.load(open(resFile))
                assert type(anns) == list, 'results is not an array of objects'
                annsQuesIds = [ann['question_id'] for ann in anns]
                assert set(annsQuesIds) == set(self.getQuesIds()), \
                'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
                for ann in anns:
                        quesId                              = ann['question_id']
                        if res.dataset['task_type'] == 'Multiple Choice':
                                assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
                        qaAnn                = self.qa[quesId]
                        ann['image_id']      = qaAnn['image_id']
                        ann['question_type'] = qaAnn['question_type']
                        ann['answer_type']   = qaAnn['answer_type']
                print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))

                res.dataset['annotations'] = anns
                res.createIndex()
                return res

        def process_dataset(self):
                for ann in self.dataset['annotations']:
                    count = {}
                    for ans in ann['answers']:
                        ans['answer'] = ans['answer'].lower()                   # lower-case every answer
                        count[ans['answer']] = count.get(ans['answer'], 0) + 1  # count occurrences of each answer

                    sorted_ans = sorted(list(count.items()),
                                        key=lambda x: x[1],
                                        reverse=True)                           # sort by count, most frequent first
                    best_ans, best_ans_count = sorted_ans[0]                    # keep the most frequent answer and its count
                    ann['best_answer'] = best_ans
                    ann['best_answer_count'] = best_ans_count

                for ques in self.questions['questions']:
                    q = ques['question']
                    q = q.replace('?', '')                                      # remove the question mark
                    q = q.lower()
                    ques['question'] = q

        def filter_by_ques_len(self, max_ques_len):
                print("Filtering the questions by length...")
                keep_ques = {}
                for ques in tqdm(self.questions['questions']):
                    if len(word_tokenize(ques['question'])) <= max_ques_len:
                        keep_ques[ques['question_id']] = \
                            keep_ques.get(ques['question_id'], 0) + 1

                self.dataset['annotations'] = \
                    [ann for ann in self.dataset['annotations'] \
                    if keep_ques.get(ann['question_id'],0)>0]
                self.questions['questions'] = \
                    [ques for ques in self.questions['questions'] \
                    if keep_ques.get(ques['question_id'],0)>0]

                self.createIndex()

        def filter_by_ans_len(self, max_ans_len, min_freq=5):
                print("Filtering the answers by length...")
                keep_ques = {}
                for ann in tqdm(self.dataset['annotations']):
                    if len(word_tokenize(ann['best_answer'])) <= max_ans_len \
                        and ann['best_answer_count']>=min_freq:
                        keep_ques[ann['question_id']] = \
                            keep_ques.get(ann['question_id'], 0) + 1

                self.dataset['annotations'] = \
                    [ann for ann in self.dataset['annotations'] \
                    if keep_ques.get(ann['question_id'],0)>0]
                self.questions['questions'] = \
                    [ques for ques in self.questions['questions'] \
                    if keep_ques.get(ques['question_id'],0)>0]

                self.createIndex()
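
The preprocessing above reduces the ten human answers per question to a single `best_answer` by majority vote. `filter_by_ans_len` then keeps a question only if that answer is at most `max_ans_len` tokens long and was given by at least `min_freq` annotators.

Below is a minimal stand-alone sketch of that logic. It makes three assumptions:

- The inputs are plain dicts shaped like VQA 1.0 annotations.
- `best_answer` and `keep_question` are hypothetical helper names, not part of the original code.
- `str.split()` stands in for NLTK's `word_tokenize`.

```python
from collections import Counter

def best_answer(answers):
    """Return (answer, count) for the most frequent lower-cased answer."""
    counts = Counter(a['answer'].lower() for a in answers)
    return counts.most_common(1)[0]

def keep_question(ann, max_ans_len=1, min_freq=5):
    """Mirror filter_by_ans_len: the best answer must be short and
    chosen by at least min_freq annotators.
    str.split() is a stand-in for word_tokenize (an assumption)."""
    ans, cnt = best_answer(ann['answers'])
    return len(ans.split()) <= max_ans_len and cnt >= min_freq

ann = {'question_id': 1,
       'answers': [{'answer': 'Yes'}] * 7 + [{'answer': 'no'}] * 3}
print(best_answer(ann['answers']))   # ('yes', 7)
print(keep_question(ann))            # True
```

dataset.py calls `filter_by_ans_len(1)` with the default `min_freq=5`. So only questions whose best answer is a single word, chosen by at least 5 of the 10 annotators, survive into training.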

Main contents of the file "vqa/vqaEval.py":

import sys
import re
from tqdm import tqdm

class VQAEval:
        def __init__(self, vqa, vqaRes, n=2):
                self.n            = n
                self.accuracy     = {}
                self.evalQA       = {}
                self.evalQuesType = {}
                self.evalAnsType  = {}
                self.vqa          = vqa
                self.vqaRes       = vqaRes
                self.params       = {'question_id': vqa.getQuesIds()}
                self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
                                     "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
                                     "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
                                     "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
                                     "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
                                     "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
                                     "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
                                     "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
                                     "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
                                     "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
                                     "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
                                     "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
                                     "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
                                     "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
                                     "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
                                     "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
                                     "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
                                     "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
                                     "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
                                     "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
                                     "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
                                     "youll": "you'll", "youre": "you're", "youve": "you've"}
                self.manualMap    = {'none': '0',
                                     'zero': '0',
                                     'one': '1',
                                     'two': '2',
                                     'three': '3',
                                     'four': '4',
                                     'five': '5',
                                     'six': '6',
                                     'seven': '7',
                                     'eight': '8',
                                     'nine': '9',
                                     'ten': '10'}
                self.articles     = ['a', 'an', 'the']
 

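                # Raw strings avoid invalid-escape warnings. "(?!<=\d)" is a lookahead that always
                # succeeds right before a period, so in effect any period not followed by a digit is stripped.
                self.periodStrip  = re.compile(r"(?!<=\d)(\.)(?!\d)")
                self.commaStrip   = re.compile(r"(\d)(\,)(\d)")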
                self.punct        = [';', r"/", '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', '>', '<', '@', '`', ',', '?', '!']

        
        def evaluate(self, quesIds=None):
                if quesIds is None:
                        quesIds = [quesId for quesId in self.params['question_id']]
                gts = {}
                res = {}
                for quesId in quesIds:
                        gts[quesId] = self.vqa.qa[quesId]
                        res[quesId] = self.vqaRes.qa[quesId]
                
                # =================================================
                # Compute accuracy
                # =================================================
                accQA       = []
                accQuesType = {}
                accAnsType  = {}
                print("computing accuracy")
                step = 0
                for quesId in tqdm(quesIds):
                        resAns      = res[quesId]['answer']
                        resAns      = resAns.replace('\n', ' ')
                        resAns      = resAns.replace('\t', ' ')
                        resAns      = resAns.strip()
                        resAns      = self.processPunctuation(resAns)
                        resAns      = self.processDigitArticle(resAns)
                        gtAcc  = []
                        gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
                        if len(set(gtAnswers)) > 1: 
                                for ansDic in gts[quesId]['answers']:
                                        ansDic['answer'] = self.processPunctuation(ansDic['answer'])
                        for gtAnsDatum in gts[quesId]['answers']:
                                otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
                                matchingAns = [item for item in otherGTAns if item['answer']==resAns]
                                acc = min(1, float(len(matchingAns))/3)
                                gtAcc.append(acc)
                        quesType    = gts[quesId]['question_type']
                        ansType     = gts[quesId]['answer_type']
                        avgGTAcc = float(sum(gtAcc))/len(gtAcc)
                        accQA.append(avgGTAcc)
                        if quesType not in accQuesType:
                                accQuesType[quesType] = []
                        accQuesType[quesType].append(avgGTAcc)
                        if ansType not in accAnsType:
                                accAnsType[ansType] = []
                        accAnsType[ansType].append(avgGTAcc)
                        self.setEvalQA(quesId, avgGTAcc)
                        self.setEvalQuesType(quesId, quesType, avgGTAcc)
                        self.setEvalAnsType(quesId, ansType, avgGTAcc)
                        step = step + 1

                self.setAccuracy(accQA, accQuesType, accAnsType)
                print("Done computing accuracy")

                self.showAccuracy(accQA, accQuesType, accAnsType)
        
        def processPunctuation(self, inText):
                outText = inText
                for p in self.punct:
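                        if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) is not None):
                                outText = outText.replace(p, '')
                        else:
                                outText = outText.replace(p, ' ')
                # The third positional argument of Pattern.sub is `count`, not flags: passing
                # re.UNICODE (== 32) only capped the number of replacements. str patterns are Unicode by default.
                outText = self.periodStrip.sub("", outText)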
                return outText
        
        def processDigitArticle(self, inText):
                outText = []
                tempText = inText.lower().split()
                for word in tempText:
                        word = self.manualMap.get(word, word)   # map number words to digits; get() avoids growing manualMap
                        if word not in self.articles:
                                outText.append(word)
                for wordId, word in enumerate(outText):
                        if word in self.contractions: 
                                outText[wordId] = self.contractions[word]
                outText = ' '.join(outText)
                return outText

        def setAccuracy(self, accQA, accQuesType, accAnsType):
                self.accuracy['overall']         = round(100*float(sum(accQA))/len(accQA), self.n)
                self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
                self.accuracy['perAnswerType']   = {ansType:  round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}

        def showAccuracy(self, accQA, accQuesType, accAnsType):
                print("Overall accuracy = %f" %(self.accuracy['overall']))

                print("Accuracy per question type:")
                for quesType in accQuesType:
                    print("quesType: %s   accuracy = %f" %(quesType, self.accuracy['perQuestionType'][quesType]))

                print("Accuracy per answer type:")
                for ansType in accAnsType:
                    print("ansType: %s   accuracy = %f" %(ansType, self.accuracy['perAnswerType'][ansType]))

        def setEvalQA(self, quesId, acc):
                self.evalQA[quesId] = round(100*acc, self.n)

        def setEvalQuesType(self, quesId, quesType, acc):
                if quesType not in self.evalQuesType:
                        self.evalQuesType[quesType] = {}
                self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
        
        def setEvalAnsType(self, quesId, ansType, acc):
                if ansType not in self.evalAnsType:
                        self.evalAnsType[ansType] = {}
                self.evalAnsType[ansType][quesId] = round(100*acc, self.n)

        def updateProgress(self, progress):
                barLength = 20
                status = ""
                if isinstance(progress, int):
                        progress = float(progress)
                if not isinstance(progress, float):
                        progress = 0
                        status = "error: progress var must be float\r\n"
                if progress < 0:
                        progress = 0
                        status = "Halt...\r\n"
                if progress >= 1:
                        progress = 1
                        status = "Done...\r\n"
                block = int(round(barLength*progress))
                text = "\rFinished Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
                print(text)
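
`evaluate` scores each prediction with the VQA consensus metric, in four steps:

1. For each of the ten human answers, it leaves that answer out.
2. It counts how many of the remaining nine match the prediction.
3. It gives `min(1, matches / 3)`.
4. It averages the ten scores.

Below is a simplified sketch of that formula. `normalize` is a hypothetical, reduced version of `processDigitArticle`: it lower-cases, maps number words and drops articles, and it skips punctuation handling. Unlike the class above, the sketch normalizes the ground-truth answers the same way as the prediction. The original only runs `processPunctuation` on the ground truth, and only when the annotators disagree.

```python
ARTICLES = {'a', 'an', 'the'}
NUMBER_WORDS = {'none': '0', 'zero': '0', 'one': '1', 'two': '2', 'three': '3',
                'four': '4', 'five': '5', 'six': '6', 'seven': '7',
                'eight': '8', 'nine': '9', 'ten': '10'}

def normalize(answer):
    """Simplified processDigitArticle: lower-case, map number words, drop articles."""
    words = [NUMBER_WORDS.get(w, w) for w in answer.lower().split()]
    return ' '.join(w for w in words if w not in ARTICLES)

def vqa_accuracy(pred, gt_answers):
    """Mean over leave-one-out subsets of min(1, matches / 3)."""
    pred = normalize(pred)
    gts = [normalize(a) for a in gt_answers]
    scores = []
    for i in range(len(gts)):
        others = gts[:i] + gts[i + 1:]          # leave the i-th human answer out
        matches = sum(1 for a in others if a == pred)
        scores.append(min(1.0, matches / 3.0))
    return sum(scores) / len(scores)

gts = ['two'] * 2 + ['3'] * 8
print(vqa_accuracy('2', gts))       # partial credit, about 0.6
print(vqa_accuracy('three', gts))   # 1.0
```

Two annotators saying "two" earn the prediction "2" only partial credit, while an answer given by three or more annotators scores 1. This is why VQA accuracies are reported as soft percentages rather than exact-match rates.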

(3) Data loading file dataset.py

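This file ties the three kinds of data (image, question, answer) together. If this is your first VQA project, assembling the data can feel quite difficult, so let's look at how this file is written: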

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from nltk.tokenize import word_tokenize            # if word_tokenize fails, reinstall NLTK or download its tokenizer data, e.g. nltk.download('punkt')

from utils.vocabulary import Vocabulary
from utils.vqa.vqa import VQA

# Dataset class
class DataSet(object):
    def __init__(self,
                 image_files,
                 question_word_idxs,
                 question_lens,
                 question_ids,
                 batch_size,
                 answer_idxs = None,
                 is_train = False,
                 shuffle = False):
        self.image_files = np.array(image_files)
        self.question_word_idxs = np.array(question_word_idxs)
        self.question_lens = np.array(question_lens)
        self.question_ids = np.array(question_ids)
        self.answer_idxs = np.array(answer_idxs)
        self.batch_size = batch_size
        self.is_train = is_train
        self.shuffle = shuffle
        self.setup()                # build the dataset once the arguments are stored

    def setup(self):
        """ Setup the dataset. """
        self.count = len(self.question_ids)            # number of questions
        self.num_batches = int(np.ceil(self.count * 1.0 / self.batch_size)) 
        self.fake_count = self.num_batches * self.batch_size - self.count
        self.idxs = list(range(self.count))
        self.reset()                # start a new pass over the data

    def reset(self):
        """ Reset the dataset. """
        self.current_idx = 0
        if self.shuffle:
            np.random.shuffle(self.idxs)

    def next_batch(self):
        """ Fetch the next batch. """
        assert self.has_next_batch()

        if self.has_full_next_batch():
            start, end = self.current_idx, self.current_idx + self.batch_size
            current_idxs = self.idxs[start:end]
        else:
            start, end = self.current_idx, self.count
            current_idxs = self.idxs[start:end]
            current_idxs += list(np.random.choice(self.count, self.fake_count))

        image_files = self.image_files[current_idxs]
        question_word_idxs = self.question_word_idxs[current_idxs]
        question_lens = self.question_lens[current_idxs]

        if self.is_train:
            answer_idxs = self.answer_idxs[current_idxs]
            self.current_idx += self.batch_size
            return image_files, question_word_idxs, question_lens, answer_idxs
        else:
            self.current_idx += self.batch_size
            return image_files, question_word_idxs, question_lens

    def has_next_batch(self):
        """ Determine whether there is a batch left. """
        return self.current_idx < self.count

    def has_full_next_batch(self):
        """ Determine whether there is a full batch left. """
        return self.current_idx + self.batch_size <= self.count
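
`DataSet.next_batch` always returns exactly `batch_size` items. When the last batch is short, it tops it up with `fake_count` randomly chosen indices, and `eval`/`test` later discard the last `fake_count` predictions.

Below is a minimal NumPy sketch of just that index bookkeeping. Two assumptions apply:

- The generator `batch_indices` is a hypothetical name, not part of the original code.
- It uses `np.random.default_rng`, which requires NumPy 1.17 or later.

```python
import numpy as np

def batch_indices(count, batch_size, shuffle=False, seed=0):
    """Yield (indices, fake_count) pairs; every index array has exactly
    batch_size entries, the last one padded with random indices."""
    rng = np.random.default_rng(seed)
    idxs = np.arange(count)
    if shuffle:
        rng.shuffle(idxs)
    num_batches = int(np.ceil(count / batch_size))
    fake_count = num_batches * batch_size - count
    for b in range(num_batches):
        current = idxs[b * batch_size:(b + 1) * batch_size]
        if len(current) < batch_size:
            # top up the short final batch with fake_count randomly drawn samples
            current = np.concatenate([current, rng.choice(count, fake_count)])
        yield current, (fake_count if b == num_batches - 1 else 0)

for idx, fake in batch_indices(10, 4):
    real = idx[:len(idx) - fake]   # drop the padded entries, as eval() and test() do
    print(idx.tolist(), 'real:', real.tolist())
```

Padding with duplicated real samples rather than zeros keeps every batch the same shape. This is presumably because the model feeds fixed-size placeholders; the model code is not shown here. The price is that the duplicates must be removed before scoring.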


# Prepare the training data
def prepare_train_data(config):
    """ Prepare the data for training the model. """
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    print("Reading the questions and answers...")
    annotations = process_vqa(vqa,
                              'COCO_train2014',
                              config.train_image_dir,
                              config.temp_train_annotation_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    answers = annotations['answer'].values
    print("Questions and answers read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    vocabulary = Vocabulary()
    if not os.path.exists(config.vocabulary_file):
        for question in tqdm(questions):
            vocabulary.add_words(word_tokenize(question))
        for answer in tqdm(answers):
            vocabulary.add_words(word_tokenize(answer))
        vocabulary.compute_frequency()
        vocabulary.save(config.vocabulary_file)
    else:
        vocabulary.load(config.vocabulary_file)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions and answers...")
    if not os.path.exists(config.temp_train_data_file):
        question_word_idxs, question_lens = process_questions(questions,
                                                              vocabulary,
                                                              config)
        answer_idxs = process_answers(answers, vocabulary)
        data = {'question_word_idxs': question_word_idxs,
                'question_lens': question_lens,
                'answer_idxs': answer_idxs}
        np.save(config.temp_train_data_file, data)
    else:
        data = np.load(config.temp_train_data_file, allow_pickle=True).item()   # NumPy >= 1.16.3 needs allow_pickle to load a saved dict
        question_word_idxs = data['question_word_idxs']
        question_lens = data['question_lens']
        answer_idxs = data['answer_idxs']
    print("Questions and answers processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size,
                      answer_idxs,
                      True,
                      True)
    print("Dataset built.")
    return dataset, config


# Prepare the evaluation data
def prepare_eval_data(config):
    """ Prepare the data for evaluating the model. """
    vqa = VQA(config.eval_answer_file, config.eval_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    print("Reading the questions...")
    annotations = process_vqa(vqa,
                              'COCO_val2014',
                              config.eval_image_dir,
                              config.temp_eval_annotation_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    print("Questions read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    if os.path.exists(config.vocabulary_file):
        vocabulary = Vocabulary(config.vocabulary_file)
    else:
        vocabulary = build_vocabulary(config)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions...")
    if not os.path.exists(config.temp_eval_data_file):
        question_word_idxs, question_lens = process_questions(questions,
                                                              vocabulary,
                                                              config)
        data = {'question_word_idxs': question_word_idxs,
                'question_lens': question_lens}
        np.save(config.temp_eval_data_file, data)
    else:
        data = np.load(config.temp_eval_data_file, allow_pickle=True).item()    # NumPy >= 1.16.3 needs allow_pickle to load a saved dict
        question_word_idxs = data['question_word_idxs']
        question_lens = data['question_lens']
    print("Questions processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size)
    print("Dataset built.")
    return vqa, dataset, vocabulary, config


# Prepare the test data
def prepare_test_data(config):
    """ Prepare the data for testing the model. """
    print("Reading the questions...")
    annotations = pd.read_csv(config.test_question_file)
    images = annotations['image'].unique()
    image_files = [os.path.join(config.test_image_dir, f) for f in images]

    temp = pd.DataFrame({'image': images, 'image_file': image_files})
    annotations = pd.merge(annotations, temp)
    annotations.to_csv(config.temp_test_info_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    print("Questions read.")
    print("Number of questions = %d" %(len(question_ids)))

    print("Building the vocabulary...")
    if os.path.exists(config.vocabulary_file):
        vocabulary = Vocabulary(config.vocabulary_file)
    else:
        vocabulary = build_vocabulary(config)
    print("Vocabulary built.")
    print("Number of words = %d" %(vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions...")
    question_word_idxs, question_lens = process_questions(questions,
                                                          vocabulary,
                                                          config)
    print("Questions processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files,
                      question_word_idxs,
                      question_lens,
                      question_ids,
                      config.batch_size)
    print("Dataset built.")
    return dataset, vocabulary, config


# Process the VQA annotations
def process_vqa(vqa, label, image_dir, annotation_file):
    """ Build a temporary annotation file for training or evaluation. """
    question_ids = list(vqa.qa.keys())
    image_ids = [vqa.qa[k]['image_id'] for k in question_ids]
    image_files = [os.path.join(image_dir, "%s_%012d.jpg" % (label, k))   # e.g. COCO_train2014_000000123456.jpg
                   for k in image_ids]
    questions = [vqa.qqa[k]['question'] for k in question_ids]
    answers = [vqa.qa[k]['best_answer'] for k in question_ids]

    annotations = pd.DataFrame({'question_id': question_ids,
                                'image_id': image_ids,
                                'image_file': image_files,
                                'question': questions,
                                'answer': answers})
    annotations.to_csv(annotation_file)
    return annotations


# Process the questions
def process_questions(questions, vocabulary, config):
    """ Tokenize the questions and translate each token into its index \
        in the vocabulary, and get the number of tokens. """
    question_word_idxs = []
    question_lens = []
    for q in tqdm(questions):
        word_idxs = vocabulary.process_sentence(q)
        current_length = len(word_idxs)
        current_word_idxs = np.zeros((config.max_question_length), np.int32)
        current_word_idxs[:current_length] = np.array(word_idxs)
        question_word_idxs.append(current_word_idxs)
        question_lens.append(current_length)
    return np.array(question_word_idxs), np.array(question_lens)
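
`process_questions` turns every question into a fixed-length row of word indices. The first `current_length` slots hold the real tokens, the rest stay 0, and the true length is returned separately so the RNN can ignore the padding.

Below is a minimal sketch. It makes these assumptions:

- A plain dict serves as a hypothetical toy vocabulary in place of the `Vocabulary` class, which is not shown here.
- `str.split()` stands in for the real tokenizer.
- `pad_questions` is a hypothetical name.

The original assumes questions were already filtered to `max_question_length` tokens and would fail on a longer one. This sketch truncates instead.

```python
import numpy as np

def pad_questions(questions, word_to_idx, max_len):
    """Map each question to a zero-padded index row plus its true length."""
    idxs = np.zeros((len(questions), max_len), dtype=np.int32)
    lens = np.zeros(len(questions), dtype=np.int32)
    for row, q in enumerate(questions):
        tokens = [word_to_idx[w] for w in q.split()][:max_len]  # truncate long questions
        idxs[row, :len(tokens)] = tokens
        lens[row] = len(tokens)
    return idxs, lens

vocab = {'what': 1, 'color': 2, 'is': 3, 'the': 4, 'cat': 5}  # hypothetical toy vocabulary
idxs, lens = pad_questions(['what color is the cat', 'what is the cat'], vocab, 6)
print(idxs)  # [[1 2 3 4 5 0]
             #  [1 3 4 5 0 0]]
print(lens)  # [5 4]
```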


# Process the answers
def process_answers(answers, vocabulary):
    """ Translate the answers into their indices in the vocabulary. """
    answer_idxs = []
    for answer in tqdm(answers):
        answer_idxs.append(vocabulary.word_to_idx(word_tokenize(answer)[0]))
    return np.array(answer_idxs)


# Build the vocabulary
def build_vocabulary(config):
    """ Build the vocabulary from the training data and save it to a file. """
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    question_ids = list(vqa.qa.keys())
    questions = [vqa.qqa[k]['question'] for k in question_ids]
    answers = [vqa.qa[k]['best_answer'] for k in question_ids]

    vocabulary = Vocabulary()
    for question in tqdm(questions):
        vocabulary.add_words(word_tokenize(question))
    for answer in tqdm(answers):
        vocabulary.add_words(word_tokenize(answer))
    vocabulary.compute_frequency()
    vocabulary.save(config.vocabulary_file)
    return vocabulary

(4) Base model file base_model.py

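This file mainly handles saving and loading the model, and it also contains the shared training, evaluation and testing loops: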

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
import json
import copy
import string

from utils.nn import NN
from utils.misc import ImageLoader
from utils.vqa.vqa import VQA
from utils.vqa.vqaEval import VQAEval

class BaseModel(object):
    def __init__(self, config):
        self.config = config
        self.is_train = True if config.phase == 'train' else False
        self.train_cnn = self.is_train and config.train_cnn
        self.image_loader = ImageLoader('./utils/ilsvrc_2012_mean.npy')
        self.image_shape = [224, 224, 3]
        self.global_step = tf.Variable(0,
                                       name = 'global_step',
                                       trainable = False)
        self.nn = NN(config)
        self.build()

    def build(self):
        raise NotImplementedError()

    def get_feed_dict(self, batch):
        raise NotImplementedError()

    def train(self, sess, train_data):
        """ Train the model using the VQA training data. """
        print("Training the model...")
        config = self.config

        if not os.path.exists(config.summary_dir):
            os.mkdir(config.summary_dir)
        train_writer = tf.summary.FileWriter(config.summary_dir, sess.graph)

        for epoch_no in tqdm(list(range(config.num_epochs)), desc='epoch'):
            for idx in tqdm(list(range(train_data.num_batches)), desc='batch'):
                batch = train_data.next_batch()
                feed_dict = self.get_feed_dict(batch)
                _, summary, global_step = sess.run([self.opt_op,
                                                    self.summary,
                                                    self.global_step],
                                                    feed_dict = feed_dict)
                if (global_step + 1) % config.save_period == 0:
                    self.save()
                train_writer.add_summary(summary, global_step)
            train_data.reset()

        print("Training complete.")

    def eval(self, sess, eval_gt_vqa, eval_data, vocabulary):
        """ Evaluate the model using the VQA validation data. """
        print("Evaluating the model...")
        config = self.config
        if not os.path.exists(config.eval_result_dir):
            os.mkdir(config.eval_result_dir)

        question_ids = eval_data.question_ids
        answers = []

        # Compute the answers to the questions
        idx = 0
        for k in tqdm(list(range(eval_data.num_batches))):
            batch = eval_data.next_batch()
            image_files, question_word_idxs, question_lens = batch
            feed_dict = self.get_feed_dict(batch)
            result = sess.run(self.prediction, feed_dict = feed_dict)

            fake_cnt = 0 if k<eval_data.num_batches-1 \
                         else eval_data.fake_count
            for l in range(eval_data.batch_size-fake_cnt):
                answer = vocabulary.words[result[l]]
                answers.append(answer)

                # Save the result in an image file
                if config.save_eval_result_as_image:
                    image_file = image_files[l]
                    image_name = image_file.split(os.sep)[-1]
                    image_name = os.path.splitext(image_name)[0]

                    q_word_idxs = question_word_idxs[l]
                    q_len = question_lens[l]
                    q_words = [vocabulary.words[q_word_idxs[i]] \
                        for i in range(q_len)]
                    if q_words[-1] != '?':
                        q_words.append('?')
                    Q = 'Q: ' + ''.join([' '+w if not w.startswith("'") \
                        and w not in string.punctuation \
                        else w for w in q_words]).strip()
                    A = 'A: ' + answer

                    image = plt.imread(image_file)
                    plt.imshow(image)
                    plt.axis('off')
                    plt.title(Q+'\n'+A)
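                    plt.savefig(os.path.join(config.eval_result_dir,
                                             image_name + '_' + str(question_ids[idx]) + '_result.jpg'))
                    plt.close()     # close the figure so successive images are not drawn onto the same axes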

                idx += 1

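        # Cast numpy integers to int (json cannot serialize np.int64) and open the file
        # in text mode, since json.dump writes str in Python 3
        results = [{'question_id': int(question_id), 'answer': answer} \
                   for question_id, answer in zip(question_ids, answers)]
        with open(config.eval_result_file, 'w') as fp:
            json.dump(results, fp)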

        # Evaluate these answers
        eval_res_vqa = eval_gt_vqa.loadRes(config.eval_result_file,
                                           config.eval_question_file)
        scorer = VQAEval(eval_gt_vqa, eval_res_vqa)
        scorer.evaluate()
        print("Evaluation complete.")

    def test(self, sess, test_data, vocabulary):
        """ Test the model using any given images and questions. """
        print("Testing the model...")
        config = self.config

        if not os.path.exists(config.test_result_dir):
            os.mkdir(config.test_result_dir)

        question_ids = test_data.question_ids
        answers = []

        # Compute the answers to the questions
        idx = 0
        for k in tqdm(list(range(test_data.num_batches))):
            batch = test_data.next_batch()
            image_files, question_word_idxs, question_lens = batch
            feed_dict = self.get_feed_dict(batch)
            result = sess.run(self.prediction, feed_dict = feed_dict)

            fake_cnt = 0 if k < test_data.num_batches-1 \
                       else test_data.fake_count
            for l in range(test_data.batch_size-fake_cnt):
                answer = vocabulary.words[result[l]]
                answers.append(answer)

                # Save the result in an image file
                image_file = image_files[l]
                image_name = image_file.split(os.sep)[-1]
                image_name = os.path.splitext(image_name)[0]

                q_word_idxs = question_word_idxs[l]
                q_len = question_lens[l]
                q_words = [vocabulary.words[q_word_idxs[i]] \
                    for i in range(q_len)]
                if q_words[-1] != '?':
                    q_words.append('?')
                Q = 'Q: ' + ''.join([' '+w if not w.startswith("'") \
                    and w not in string.punctuation \
                    else w for w in q_words]).strip()
                A = 'A: ' + answer

                image = plt.imread(image_file)
                plt.imshow(image)
                plt.axis('off')
                plt.title(Q+'\n'+A)
                plt.savefig(os.path.join(config.test_result_dir, \
                    image_name + '_' + str(question_ids[idx]) \
                    + '_result.jpg'))
                plt.close()     # close the figure so successive images are not drawn onto the same axes

                idx += 1

        # Save the answers to a file
        test_info = pd.read_csv(config.temp_test_info_file)
        results = pd.DataFrame({'question_id': question_ids,
                                'answer': answers})
        results = pd.merge(test_info, results)
        results.to_csv(config.test_result_file)
        print("Testing complete.")

    def save(self):
        """ Save the model. """
        config = self.config
        data = {v.name: v.eval() for v in tf.global_variables()}
        save_path = os.path.join(config.save_dir, str(self.global_step.eval()))

        print("Saving the model to %s..." % (save_path + ".npy"))
        np.save(save_path, data)
        info_file = open(os.path.join(config.save_dir, "config.pickle"), "wb")
        config_ = copy.copy(config)
        config_.global_step = self.global_step.eval()
        pickle.dump(config_, info_file)
        info_file.close()
        print("Model saved.")

    def load(self, sess, model_file=None):
        """ Load the model. """
        config = self.config
        if model_file is not None:
            save_path = model_file
        else:
            info_path = os.path.join(config.save_dir, "config.pickle")
            info_file = open(info_path, "rb")
            config = pickle.load(info_file)
            global_step = config.global_step
            info_file.close()
            save_path = os.path.join(config.save_dir,
                                     str(global_step)+".npy")

        print("Loading the model from %s..." %save_path)
        data_dict = np.load(save_path, allow_pickle=True).item()
        count = 0
        for v in tqdm(tf.global_variables()):
            if v.name in data_dict.keys():
                sess.run(v.assign(data_dict[v.name]))
                count += 1
        print("%d tensors loaded." %count)

    def load_cnn(self, session, data_path, ignore_missing=True):
        """ Load a pretrained CNN model. """
        print("Loading the CNN from %s..." %data_path)
        # The pretrained CNN file is a pickled dict {op_name: {param_name: array}}
        data_dict = np.load(data_path, allow_pickle=True).item()
        count = 0
        for op_name in tqdm(data_dict):
            with tf.variable_scope(op_name, reuse=True):
                for param_name, data in data_dict[op_name].items():
                    try:
                        var = tf.get_variable(param_name)
                        session.run(var.assign(data))
                        count += 1
                    except ValueError:
                        pass
        print("%d tensors loaded." %count)
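save() and load() above keep the whole model as a plain dictionary that maps each variable's name to its value, stored as <save_dir>/<global_step>.npy; load() then restores only the variables whose names appear in that dictionary, so variables missing from the file keep their initialized values. The sketch below reproduces this name-matching scheme with the standard library only, assuming pickle as a stand-in for np.save/np.load. The helpers save_params and load_params, the .pkl extension and the example variable names are hypothetical, not part of the original code.

```python
import os
import pickle
import tempfile


def save_params(params, save_dir, global_step):
    """Write a {variable_name: value} dict to <save_dir>/<global_step>.pkl (hypothetical helper)."""
    path = os.path.join(save_dir, str(global_step) + ".pkl")
    with open(path, "wb") as f:
        pickle.dump(params, f)
    return path


def load_params(model_vars, path):
    """Overwrite entries of model_vars whose names exist in the saved file.

    Like load(), variables that are missing from the file keep their
    current (e.g. freshly initialized) values. Returns the number restored.
    """
    with open(path, "rb") as f:
        saved = pickle.load(f)
    count = 0
    for name in model_vars:
        if name in saved:
            model_vars[name] = saved[name]
            count += 1
    return count


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as save_dir:
        path = save_params({"fc1/kernel:0": [0.1, 0.2], "fc1/bias:0": [0.0]}, save_dir, 1000)
        model_vars = {"fc1/kernel:0": None, "fc2/kernel:0": None}
        print("%d tensors loaded." % load_params(model_vars, path))
```

Matching by name is what makes partial restores possible: a checkpoint from a slightly different graph still loads every variable it shares with the current one.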

(5) Episodic memory module: episodic_memory.py

This file implements the episodic memory module:

import tensorflow as tf
from utils.nn import NN

class AttnGRU(object):
    """ Attention-based GRU (used by the Episodic Memory Module). """
    def __init__(self, config):
        self.nn = NN(config)
        self.num_units = config.num_gru_units

    def __call__(self, inputs, state, attention):
        with tf.variable_scope('attn_gru'):
            r_input = tf.concat([inputs, state], axis = 1)
            r_input = self.nn.dropout(r_input)
            r = self.nn.dense(r_input,
                              units = self.num_units,
                              activation = None,
                              use_bias = False,
                              name = 'fc1')
            b = tf.get_variable('fc1/bias',
                                shape = [self.num_units],
                                initializer = tf.constant_initializer(1.0))
            r = tf.nn.bias_add(r, b)
            r = tf.sigmoid(r)

            c_input = tf.concat([inputs, r*state], axis = 1)
            c_input = self.nn.dropout(c_input)
            c = self.nn.dense(c_input,
                              units = self.num_units,
                              activation = tf.tanh,
                              name = 'fc2')

            new_state = attention * c + (1 - attention) * state
        return new_state

class EpisodicMemory(object):
    """ Episodic Memory Module. """
    def __init__(self, config, num_facts, question, facts):
        self.nn = NN(config)
        self.num_units = config.num_gru_units
        self.num_facts = num_facts
        self.question = question
        self.facts = facts
        self.attention = config.attention
        if self.attention == 'gru':
            self.attn_gru = AttnGRU(config)

    def new_fact(self, memory):
        """ Get the context vector by using either soft attention or
            attention-based GRU. """
        fact_list = tf.unstack(self.facts, axis = 1)
        mixed_fact = tf.zeros_like(fact_list[0])

        with tf.variable_scope('attend'):
            attentions = self.attend(memory)

        if self.attention == 'gru':
            with tf.variable_scope('attn_gate') as scope:
                attentions = tf.unstack(attentions, axis = 1)
                for ctx, att in zip(fact_list, attentions):
                    mixed_fact = self.attn_gru(ctx,
                                               mixed_fact,
                                               tf.expand_dims(att, 1))
                    scope.reuse_variables()
        else:
            mixed_fact = tf.reduce_sum(self.facts*tf.expand_dims(attentions, 2),
                                       axis = 1)

        return mixed_fact

    def attend(self, memory):
        """ Get the attention weights. """
        c = self.facts
        q = tf.tile(tf.expand_dims(self.question, 1), [1, self.num_facts, 1])
        m = tf.tile(tf.expand_dims(memory, 1), [1, self.num_facts, 1])

        z = tf.concat([c*q, c*m, tf.abs(c-q), tf.abs(c-m)], 2)
        z = tf.reshape(z, [-1, 4*self.num_units])

        z = self.nn.dropout(z)
        z1 = self.nn.dense(z,
                           units = self.num_units,
                           activation = tf.tanh,
                           name = 'fc1')
        z1 = self.nn.dropout(z1)
        z2 = self.nn.dense(z1,
                           units = 1,
                           activation = None,
                           use_bias = False,
                           name = 'fc2')
        z2 = tf.reshape(z2, [-1, self.num_facts])

        attentions = tf.nn.softmax(z2)
        return attentions
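The two branches of new_fact() differ in how they turn the attention weights into a context vector. With soft attention the context is just the weighted sum of the facts. With the attention-based GRU the facts are processed in order and each attention weight g_i plays the role of the update gate, as in the line new_state = attention * c + (1 - attention) * state of AttnGRU, so the order of the facts matters. The pure-Python sketch below isolates that difference. To keep it short, the candidate state c is passed in directly instead of being computed by the reset gate and the tanh layer (fc1/fc2), so it illustrates the gating rule only and is not a reimplementation of AttnGRU; the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax, like tf.nn.softmax over the fact axis in attend()."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_attention(facts, attentions):
    """Weighted sum of the fact vectors (the non-GRU branch of new_fact())."""
    dim = len(facts[0])
    return [sum(a * f[k] for f, a in zip(facts, attentions)) for k in range(dim)]


def attn_gated_update(candidates, attentions, init_state):
    """Apply h_i = g_i * c_i + (1 - g_i) * h_{i-1} over the facts in order.

    candidates[i] stands in for the candidate state c that AttnGRU computes
    from the fact and the reset gate; here it is supplied directly.
    """
    state = list(init_state)
    for c, g in zip(candidates, attentions):
        state = [g * ck + (1 - g) * hk for ck, hk in zip(c, state)]
    return state


if __name__ == "__main__":
    facts = [[1.0, 0.0], [0.0, 1.0]]
    g = softmax([0.0, 0.0])                          # equal scores -> [0.5, 0.5]
    print(soft_attention(facts, g))                  # [0.5, 0.5]
    print(attn_gated_update(facts, g, [0.0, 0.0]))   # [0.25, 0.5]
```

With equal weights, soft attention treats both facts symmetrically, while the gated update lets the later fact dominate; this is why the attention-based GRU can take the ordering of the facts into account and soft attention cannot.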

(6) Model file: model.py

This file assembles the question-answering model from the modules above:

import tensorflow as tf
import numpy as np

from base_model import BaseModel
from episodic_memory import EpisodicMemory

class QuestionAnswerer(BaseModel):
    def build(self):
        """ Build the model. """
        self.build_cnn()
        self.build_rnn()
        if self.is_train:
            self.build_optimizer()
            self.build_summary()

    def build_cnn(self):
        """ Build the CNN. """
        print("Building the CNN...")
        if self.config.cnn =='vgg16':
            self.build_vgg16()
        else:
            self.build_resnet50()
        print("CNN built.")

    def build_vgg16(self):
        """ Build the VGG16 net. """
        config = self.config

        images = tf.placeholder(
            dtype = tf.float32,
            shape = [config.batch_size] + self.image_shape)

        conv1_1_feats = self.nn.conv2d(images, 64, name = 'conv1_1')
        conv1_2_feats = self.nn.conv2d(conv1_1_feats, 64, name = 'conv1_2')
        pool1_feats = self.nn.max_pool2d(conv1_2_feats, name = 'pool1')

        conv2_1_feats = self.nn.conv2d(pool1_feats, 128, name = 'conv2_1')
        conv2_2_feats = self.nn.conv2d(conv2_1_feats, 128, name = 'conv2_2')
        pool2_feats = self.nn.max_pool2d(conv2_2_feats, name = 'pool2')

        conv3_1_feats = self.nn.conv2d(pool2_feats, 256, name = 'conv3_1')
        conv3_2_feats = self.nn.conv2d(conv3_1_feats, 256, name = 'conv3_2')
        conv3_3_feats = self.nn.conv2d(conv3_2_feats, 256, name = 'conv3_3')
        pool3_feats = self.nn.max_pool2d(conv3_3_feats, name = 'pool3')

        conv4_1_feats = self.nn.conv2d(pool3_feats, 512, name = 'conv4_1')
        conv4_2_feats = self.nn.conv2d(conv4_1_feats, 512, name = 'conv4_2')
        conv4_3_feats = self.nn.conv2d(conv4_2_feats, 512, name = 'conv4_3')
        pool4_feats = self.nn.max_pool2d(conv4_3_feats, name = 'pool4')

        conv5_1_feats = self.nn.conv2d(pool4_feats, 512, name = 'conv5_1')
        conv5_2_feats = self.nn.conv2d(conv5_1_feats, 512, name = 'conv5_2')
        conv5_3_feats = self.nn.conv2d(conv5_2_feats, 512, name = 'conv5_3')

        self.permutation = self.get_permutation(14, 14)
        conv5_3_feats_flat = self.flatten_feats(conv5_3_feats, 512)
        self.conv_feats = conv5_3_feats_flat
        self.conv_feat_shape = [196, 512]
        self.images = images

    def build_resnet50(self):
        """ Build the ResNet50. """
        config = self.config

        images = tf.placeholder(
            dtype = tf.float32,
            shape = [config.batch_size] + self.image_shape)

        conv1_feats = self.nn.conv2d(images,
                                  filters = 64,
                                  kernel_size = (7, 7),
                                  strides = (2, 2),
                                  activation = None,
                                  name = 'conv1')
        conv1_feats = self.nn.batch_norm(conv1_feats, 'bn_conv1')
        conv1_feats = tf.nn.relu(conv1_feats)
        pool1_feats = self.nn.max_pool2d(conv1_feats,
                                      pool_size = (3, 3),
                                      strides = (2, 2),
                                      name = 'pool1')

        res2a_feats = self.resnet_block(pool1_feats, 'res2a', 'bn2a', 64, 1)
        res2b_feats = self.resnet_block2(res2a_feats, 'res2b', 'bn2b', 64)
        res2c_feats = self.resnet_block2(res2b_feats, 'res2c', 'bn2c', 64)

        res3a_feats = self.resnet_block(res2c_feats, 'res3a', 'bn3a', 128)
        res3b_feats = self.resnet_block2(res3a_feats, 'res3b', 'bn3b', 128)
        res3c_feats = self.resnet_block2(res3b_feats, 'res3c', 'bn3c', 128)
        res3d_feats = self.resnet_block2(res3c_feats, 'res3d', 'bn3d', 128)

        res4a_feats = self.resnet_block(res3d_feats, 'res4a', 'bn4a', 256)
        res4b_feats = self.resnet_block2(res4a_feats, 'res4b', 'bn4b', 256)
        res4c_feats = self.resnet_block2(res4b_feats, 'res4c', 'bn4c', 256)
        res4d_feats = self.resnet_block2(res4c_feats, 'res4d', 'bn4d', 256)
        res4e_feats = self.resnet_block2(res4d_feats, 'res4e', 'bn4e', 256)
        res4f_feats = self.resnet_block2(res4e_feats, 'res4f', 'bn4f', 256)

        res5a_feats = self.resnet_block(res4f_feats, 'res5a', 'bn5a', 512)
        res5b_feats = self.resnet_block2(res5a_feats, 'res5b', 'bn5b', 512)
        res5c_feats = self.resnet_block2(res5b_feats, 'res5c', 'bn5c', 512)

        self.permutation = self.get_permutation(7, 7)
        res5c_feats_flat = self.flatten_feats(res5c_feats, 2048)
        self.conv_feats = res5c_feats_flat
        self.conv_feat_shape = [49, 2048]
        self.images = images

    def resnet_block(self, inputs, name1, name2, c, s=2):
        """ A basic block of ResNet. """
        branch1_feats = self.nn.conv2d(inputs,
                                    filters = 4*c,
                                    kernel_size = (1, 1),
                                    strides = (s, s),
                                    activation = None,
                                    use_bias = False,
                                    name = name1+'_branch1')
        branch1_feats = self.nn.batch_norm(branch1_feats, name2+'_branch1')

        branch2a_feats = self.nn.conv2d(inputs,
                                     filters = c,
                                     kernel_size = (1, 1),
                                     strides = (s, s),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2a')
        branch2a_feats = self.nn.batch_norm(branch2a_feats, name2+'_branch2a')
        branch2a_feats = tf.nn.relu(branch2a_feats)

        branch2b_feats = self.nn.conv2d(branch2a_feats,
                                     filters = c,
                                     kernel_size = (3, 3),
                                     strides = (1, 1),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2b')
        branch2b_feats = self.nn.batch_norm(branch2b_feats, name2+'_branch2b')
        branch2b_feats = tf.nn.relu(branch2b_feats)

        branch2c_feats = self.nn.conv2d(branch2b_feats,
                                     filters = 4*c,
                                     kernel_size = (1, 1),
                                     strides = (1, 1),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2c')
        branch2c_feats = self.nn.batch_norm(branch2c_feats, name2+'_branch2c')

        outputs = branch1_feats + branch2c_feats
        outputs = tf.nn.relu(outputs)
        return outputs

    def resnet_block2(self, inputs, name1, name2, c):
        """ Another basic block of ResNet. """
        branch2a_feats = self.nn.conv2d(inputs,
                                     filters = c,
                                     kernel_size = (1, 1),
                                     strides = (1, 1),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2a')
        branch2a_feats = self.nn.batch_norm(branch2a_feats, name2+'_branch2a')
        branch2a_feats = tf.nn.relu(branch2a_feats)

        branch2b_feats = self.nn.conv2d(branch2a_feats,
                                     filters = c,
                                     kernel_size = (3, 3),
                                     strides = (1, 1),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2b')
        branch2b_feats = self.nn.batch_norm(branch2b_feats, name2+'_branch2b')
        branch2b_feats = tf.nn.relu(branch2b_feats)

        branch2c_feats = self.nn.conv2d(branch2b_feats,
                                     filters = 4*c,
                                     kernel_size = (1, 1),
                                     strides = (1, 1),
                                     activation = None,
                                     use_bias = False,
                                     name = name1+'_branch2c')
        branch2c_feats = self.nn.batch_norm(branch2c_feats, name2+'_branch2c')

        outputs = inputs + branch2c_feats
        outputs = tf.nn.relu(outputs)
        return outputs

    def get_permutation(self, height, width):
        """ Get the permutation corresponding to the snake-like walk described
            in the paper. Used to flatten the convolutional feats. """
        permutation = np.zeros(height*width, np.int32)
        for i in range(height):
            for j in range(width):
                permutation[i*width+j] = i*width+j if i%2==0  \
                                         else (i+1)*width-j-1
        return permutation

    def flatten_feats(self, feats, channels):
        """ Flatten the feats. """
        temp1 = tf.reshape(feats, [self.config.batch_size, -1, channels])
        temp1 = tf.transpose(temp1, [1, 0, 2])
        temp2 = tf.gather(temp1, self.permutation)
        temp2 = tf.transpose(temp2, [1, 0, 2])
        return temp2

    def build_rnn(self):
        """ Build the RNN. """
        print("Building the RNN...")
        config = self.config

        facts = self.conv_feats
        num_facts, dim_fact = self.conv_feat_shape

        # Setup the placeholders
        question_word_idxs = tf.placeholder(
            dtype = tf.int32,
            shape = [config.batch_size, config.max_question_length])
        question_lens = tf.placeholder(
            dtype = tf.int32,
            shape = [config.batch_size])
        if self.is_train:
            answer_idxs = tf.placeholder(
                dtype = tf.int32,
                shape = [config.batch_size])
            if config.question_encoding == 'positional':
                position_weights = tf.placeholder(
                    dtype = tf.float32,
                    shape = [config.batch_size, \
                             config.max_question_length, \
                             config.dim_embedding])

        # Setup the word embedding
        with tf.variable_scope("word_embedding"):
            embedding_matrix = tf.get_variable(
                name = 'weights',
                shape = [config.vocabulary_size, config.dim_embedding],
                initializer = self.nn.fc_kernel_initializer,
                regularizer = self.nn.fc_kernel_regularizer,
                trainable = self.is_train)

        # Encode the questions
        with tf.variable_scope('question_encoding'):
            question_embeddings = tf.nn.embedding_lookup(
                embedding_matrix,
                question_word_idxs)

            if config.question_encoding == 'positional':
                # use positional encoding
                self.build_position_weights()
                question_encodings = question_embeddings * position_weights
                question_encodings = tf.reduce_sum(question_encodings,
                                                    axis = 1)
            else:
                # use GRU encoding
                outputs, _ = tf.nn.dynamic_rnn(
                    self.nn.gru(),
                    inputs = question_embeddings,
                    dtype = tf.float32)

                question_encodings = []
                for k in range(config.batch_size):
                    question_encoding = tf.slice(outputs,
                                                 [k, question_lens[k]-1, 0],
                                                 [1, 1, config.num_gru_units])
                    question_encodings.append(tf.squeeze(question_encoding))
                question_encodings = tf.stack(question_encodings, axis = 0)

        # Encode the facts
        with tf.variable_scope('input_fusion'):
            if config.embed_fact:
                facts = tf.reshape(facts, [-1, dim_fact])
                facts = self.nn.dropout(facts)
                facts = self.nn.dense(
                    facts,
                    units = config.dim_embedding,
                    activation = tf.tanh,
                    name = 'fc')
                facts = tf.reshape(facts, [-1, num_facts, config.dim_embedding])

            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                self.nn.gru(),
                self.nn.gru(),
                inputs = facts,
                dtype = tf.float32)
            outputs_fw, outputs_bw = outputs
            fact_encodings = outputs_fw + outputs_bw

        # Episodic Memory Update
        with tf.variable_scope('episodic_memory'):
            episode = EpisodicMemory(config,
                                     num_facts,
                                     question_encodings,
                                     fact_encodings)
            memory = tf.identity(question_encodings)

            if config.tie_memory_weight:
                scope_list = ['layer'] * config.memory_step
            else:
                scope_list = ['layer'+str(t) for t in range(config.memory_step)]

            for t in range(config.memory_step):
                with tf.variable_scope(scope_list[t], reuse = tf.AUTO_REUSE):
                    fact = episode.new_fact(memory)
                    if config.memory_update == 'gru':
                        gru = self.nn.gru()
                        memory = gru(fact, memory)[0]
                    else:
                        expanded_memory = tf.concat(
                            [memory, fact, question_encodings],
                            axis = 1)
                        expanded_memory = self.nn.dropout(expanded_memory)
                        memory = self.nn.dense(
                            expanded_memory,
                            units = config.num_gru_units,
                            activation = tf.nn.relu,
                            name = 'fc')

        # Compute the result
        with tf.variable_scope('result'):
            expanded_memory = tf.concat([memory, question_encodings],
                                        axis = 1)
            expanded_memory = self.nn.dropout(expanded_memory)
            logits = self.nn.dense(expanded_memory,
                                   units = config.vocabulary_size,
                                   activation = None,
                                   name = 'logits')
            prediction = tf.argmax(logits, axis = 1)

        # Compute the loss and accuracy if necessary
        if self.is_train:
            cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = answer_idxs,
                logits = logits)
            cross_entropy_loss = tf.reduce_mean(cross_entropy_loss)
            reg_loss = tf.losses.get_regularization_loss()
            total_loss = cross_entropy_loss + reg_loss

            ground_truth = tf.cast(answer_idxs, tf.int64)
            prediction_correct = tf.where(
                tf.equal(prediction, ground_truth),
                tf.cast(tf.ones_like(prediction), tf.float32),
                tf.cast(tf.zeros_like(prediction), tf.float32))
            accuracy = tf.reduce_mean(prediction_correct)

        self.question_word_idxs = question_word_idxs
        self.question_lens = question_lens
        self.prediction = prediction

        if self.is_train:
            self.answer_idxs = answer_idxs
            if config.question_encoding == 'positional':
                self.position_weights = position_weights
            self.total_loss = total_loss
            self.cross_entropy_loss = cross_entropy_loss
            self.reg_loss = reg_loss
            self.accuracy = accuracy

        print("RNN built.")

    def build_position_weights(self):
        """ Setup the weights for the positional encoding of questions. """
        config = self.config
        D = config.dim_embedding
        pos_weights = []
        for M in range(config.max_question_length):
            cur_pos_weights = []
            for j in range(config.max_question_length):
                if j <= M:
                    temp = [1.0-(j+1.0)/(M+1.0) \
                            -((d+1.0)/D)*(1-2.0*(j+1.0)/(M+1.0)) \
                            for d in range(D)]
                else:
                    temp = [0.0] * D
                cur_pos_weights.append(temp)
            pos_weights.append(cur_pos_weights)
        self.pos_weights = np.array(pos_weights, np.float32)

    def build_optimizer(self):
        """ Setup the training operation. """
        config = self.config

        learning_rate = tf.constant(config.initial_learning_rate)
        if config.learning_rate_decay_factor < 1.0:
            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps = config.num_steps_per_decay,
                    decay_rate = config.learning_rate_decay_factor,
                    staircase = True)
            learning_rate_decay_fn = _learning_rate_decay_fn
        else:
            learning_rate_decay_fn = None

        with tf.variable_scope('optimizer', reuse = tf.AUTO_REUSE):
            if config.optimizer == 'Adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate = config.initial_learning_rate,
                    beta1 = config.beta1,
                    beta2 = config.beta2,
                    epsilon = config.epsilon
                    )
            elif config.optimizer == 'RMSProp':
                optimizer = tf.train.RMSPropOptimizer(
                    learning_rate = config.initial_learning_rate,
                    decay = config.decay,
                    momentum = config.momentum,
                    centered = config.centered,
                    epsilon = config.epsilon
                )
            elif config.optimizer == 'Momentum':
                optimizer = tf.train.MomentumOptimizer(
                    learning_rate = config.initial_learning_rate,
                    momentum = config.momentum,
                    use_nesterov = config.use_nesterov
                )
            else:
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate = config.initial_learning_rate
                )

            opt_op = tf.contrib.layers.optimize_loss(
                loss = self.total_loss,
                global_step = self.global_step,
                learning_rate = learning_rate,
                optimizer = optimizer,
                clip_gradients = config.clip_gradients,
                learning_rate_decay_fn = learning_rate_decay_fn)

        self.opt_op = opt_op

    def build_summary(self):
        """ Build the summary (for TensorBoard visualization). """
        with tf.name_scope("variables"):
            for var in tf.trainable_variables():
                with tf.name_scope(var.name[:var.name.find(":")]):
                    self.variable_summary(var)

        with tf.name_scope("metrics"):
            tf.summary.scalar("cross_entropy_loss", self.cross_entropy_loss)
            tf.summary.scalar("reg_loss", self.reg_loss)
            tf.summary.scalar("total_loss", self.total_loss)
            tf.summary.scalar("accuracy", self.accuracy)

        self.summary = tf.summary.merge_all()

    def variable_summary(self, var):
        """ Build the summary for a variable. """
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

    def get_feed_dict(self, batch):
        """ Get the feed dictionary for the current batch. """
        config = self.config
        if self.is_train:
            # training phase
            image_files, question_word_idxs, question_lens, answer_idxs = batch
            images = self.image_loader.load_images(image_files)
            if config.question_encoding == 'positional':
                position_weights = [self.pos_weights[question_lens[i]-1, :, :]
                                    for i in range(config.batch_size)]
                position_weights = np.array(position_weights, np.float32)
                return {self.images: images,
                        self.question_word_idxs: question_word_idxs,
                        self.question_lens: question_lens,
                        self.answer_idxs: answer_idxs,
                        self.position_weights: position_weights}
            else:
                return {self.images: images,
                        self.question_word_idxs: question_word_idxs,
                        self.question_lens: question_lens,
                        self.answer_idxs: answer_idxs}
        else:
            # evaluation or testing phase
            image_files, question_word_idxs, question_lens = batch
            images = self.image_loader.load_images(image_files)
            return {self.images: images,
                    self.question_word_idxs: question_word_idxs,
                    self.question_lens: question_lens}
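model.py precomputes two small tables in NumPy: the snake-like permutation from get_permutation(), which reorders the flattened CNN grid so that even rows are read left to right and odd rows right to left, and the positional-encoding weights from build_position_weights(), where for a question of M words the weight of word j in embedding dimension d (both 1-based, D dimensions in total) is l_jd = (1 - j/M) - (d/D)(1 - 2j/M). The pure-Python sketch below recomputes both for a tiny case so the patterns are easy to inspect; the function names are hypothetical and simply mirror the methods above.

```python
def snake_permutation(height, width):
    """Same indexing as get_permutation(): even rows left to right, odd rows reversed."""
    perm = []
    for i in range(height):
        for j in range(width):
            perm.append(i * width + j if i % 2 == 0 else (i + 1) * width - j - 1)
    return perm


def position_weights(question_len, max_len, dim):
    """Positional-encoding weights for one question of question_len words.

    Mirrors self.pos_weights[question_len - 1] in build_position_weights():
    row j (0-based) is 1 - (j+1)/M - ((d+1)/D) * (1 - 2(j+1)/M) for j < M,
    and all zeros for padded positions.
    """
    M, D = float(question_len), float(dim)
    rows = []
    for j in range(max_len):
        if j < question_len:
            rows.append([1.0 - (j + 1.0) / M - ((d + 1.0) / D) * (1.0 - 2.0 * (j + 1.0) / M)
                         for d in range(dim)])
        else:
            rows.append([0.0] * dim)
    return rows


if __name__ == "__main__":
    print(snake_permutation(3, 3))      # [0, 1, 2, 5, 4, 3, 6, 7, 8]
    print(position_weights(2, 3, 2))    # [[0.5, 0.5], [0.5, 1.0], [0.0, 0.0]]
```

In build_rnn() the word embeddings are multiplied by these weights and summed over the word axis (tf.reduce_sum(..., axis = 1)), so each question becomes a single vector of size dim_embedding.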

(7) Main file: main.py

The main file decides which phase to run: training, evaluation, or testing.

#!/usr/bin/python
import tensorflow as tf

from config import Config
from model import QuestionAnswerer
from dataset import prepare_train_data, prepare_eval_data, prepare_test_data

FLAGS = tf.app.flags.FLAGS

tf.flags.DEFINE_string('phase', 'train',
                       'The phase can be train, eval or test')

tf.flags.DEFINE_boolean('load', False,
                        'Turn on to load a pretrained model from either '
                        'the latest checkpoint or a specified file')

tf.flags.DEFINE_string('model_file', None,
                       'If specified, load a pretrained model from this file')

tf.flags.DEFINE_boolean('load_cnn', False,
                        'Turn on to load a pretrained CNN model')

tf.flags.DEFINE_string('cnn_model_file', './vgg16_no_fc.npy',
                       'File containing a pretrained CNN model')

tf.flags.DEFINE_boolean('train_cnn', False,
                        'Turn on to train both CNN and RNN. '
                        'Otherwise, only RNN is trained')

def main(argv):
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn

    with tf.Session() as sess:
        if FLAGS.phase == 'train':
            # training phase
            data, config = prepare_train_data(config)
            model = QuestionAnswerer(config)
            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            model.train(sess, data)

        elif FLAGS.phase == 'eval':
            # evaluation phase
            vqa, data, vocabulary, config = prepare_eval_data(config)
            model = QuestionAnswerer(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, vqa, data, vocabulary)

        else:
            # testing phase
            data, vocabulary, config = prepare_test_data(config)
            model = QuestionAnswerer(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)

if __name__ == '__main__':
    tf.app.run()
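main.py reads its settings through tf.app.flags / tf.flags, which are TensorFlow 1.x APIs. If you want to check which combination of flags a run needs without starting TensorFlow, the same options can be expressed with the standard-library argparse module. This is a swapped-in technique, not the original code, and the helper describe_run() is hypothetical. It encodes one rule visible in main(): in the train phase a checkpoint is loaded only when --load is on, while eval and test always call model.load(), which falls back to the latest checkpoint when no model_file is given.

```python
import argparse


def build_parser():
    """argparse equivalent of the flags defined in main.py (defaults copied from it)."""
    p = argparse.ArgumentParser(description="DMN+ VQA runner (argparse sketch)")
    # choices is stricter than main.py, where any other phase falls through to test
    p.add_argument("--phase", default="train", choices=["train", "eval", "test"],
                   help="The phase can be train, eval or test")
    p.add_argument("--load", action="store_true",
                   help="Load a pretrained model from the latest checkpoint or --model_file")
    p.add_argument("--model_file", default=None,
                   help="If specified, load a pretrained model from this file")
    p.add_argument("--load_cnn", action="store_true",
                   help="Load a pretrained CNN model")
    p.add_argument("--cnn_model_file", default="./vgg16_no_fc.npy",
                   help="File containing a pretrained CNN model")
    p.add_argument("--train_cnn", action="store_true",
                   help="Train both CNN and RNN; otherwise only the RNN is trained")
    return p


def describe_run(args):
    """List the steps main() would take for these arguments (hypothetical helper)."""
    steps = []
    if args.phase == "train":
        steps.append("initialize variables")
        if args.load:
            steps.append("load model from " + (args.model_file or "latest checkpoint"))
        if args.load_cnn:
            steps.append("load CNN from " + args.cnn_model_file)
        steps.append("train CNN+RNN" if args.train_cnn else "train RNN only")
    else:
        steps.append("load model from " + (args.model_file or "latest checkpoint"))
        steps.append(args.phase)
    return steps


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--phase=train", "--load", "--model_file=./models/xxxxxx.npy"])
    print(describe_run(args))
    # ['initialize variables', 'load model from ./models/xxxxxx.npy', 'train RNN only']
```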

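4. Training

First, training.

For the first training run, set phase to train. To start the CNN from the pretrained VGG16 weights, also turn on --load_cnn, which reads cnn_model_file (default './vgg16_no_fc.npy').

My GPU (a GTX 1050 Ti) has little memory and running the code as-is ran out of memory, so I reduced the batch size to 4. The dataset is large and training is slow: after about six days, training had covered only about one and a half epochs and still had not finished. The figure below shows the state after one day of training:

The figure below is from the fourth day, when the first epoch was almost complete:

To resume an interrupted run, keep phase set to train, turn on --load, and set model_file to './models/xxxxxx.npy'; in the training phase main() only calls model.load() when --load is on.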
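save() names each checkpoint after the global step (<save_dir>/<global_step>.npy) and writes that step into config.pickle, which load() uses when model_file is not given. If you pick a checkpoint by hand instead, compare the numeric steps rather than the file names: as strings, '9999.npy' sorts after '72999.npy'. The helper below is a hypothetical sketch that returns the checkpoint with the largest step from a directory listing; the './models/' default is an assumption taken from the paths used in this post.

```python
import os
import re

# Checkpoint files written by save() are named "<global_step>.npy"
_CKPT = re.compile(r"^(\d+)\.npy$")


def latest_checkpoint(filenames, save_dir="./models/"):
    """Return the path of the checkpoint with the largest global step, or None."""
    steps = []
    for name in filenames:
        m = _CKPT.match(name)
        if m:
            steps.append(int(m.group(1)))  # compare numerically, not as strings
    if not steps:
        return None
    return os.path.join(save_dir, str(max(steps)) + ".npy")


if __name__ == "__main__":
    names = ["9999.npy", "72999.npy", "config.pickle"]
    print(latest_checkpoint(names))   # ./models/72999.npy
    print(sorted(names)[-1])          # config.pickle (plain string sorting is not what you want)
```

In practice you would pass os.listdir('./models/') and hand the result to --model_file.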

To monitor training, run the following from the command line: tensorboard --logdir='./summary/'

To train the CNN and the RNN together, turn on --train_cnn; with it off, only the RNN is trained.

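Next, evaluation.

Set two arguments: --phase=eval --model_file='./models/xxxxxx.npy'

The results are saved to 'val/results.json'.

Finally, testing.

The trained model can answer arbitrary questions about JPEG images. Put the images under 'test/images', create a CSV file (containing the image, the question, and the question id) under 'test', and run with: --phase=test --model_file='./models/xxxxxx.npy'

The results are saved under 'test/results'.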
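The test CSV can be written with the standard csv module. The column names image, question and question_id below are an assumption based on the description of the file (image, question, question id) and on the question_id column that test() uses when merging the answers; check them against the data-loading code (prepare_test_data) before relying on them. The image column is assumed to hold the file name inside test/images, and the file names and questions are placeholders.

```python
import csv
import io


def write_question_csv(rows, f):
    """Write (image, question, question_id) rows with a header line (column names assumed)."""
    writer = csv.writer(f)
    writer.writerow(["image", "question", "question_id"])
    for image, question, question_id in rows:
        writer.writerow([image, question, question_id])


if __name__ == "__main__":
    rows = [("baseball.jpg", "How many baseball fields are there?", 1),
            ("soccer.jpg", "What sport are they playing?", 2)]
    # For a real run: with open("./test/question.csv", "w", newline="") as f: ...
    buf = io.StringIO()
    write_question_csv(rows, buf)
    print(buf.getvalue())
```

Using the csv module instead of joining strings by hand takes care of quoting questions that themselves contain commas.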

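5. Results

A checkpoint was saved every 1000 batches. I trained for a little over 74,000 batches and used 72999.npy for evaluation and testing:

Testing the model takes a long time: there are more than 40,000 images in total, and it took about six hours:

Near the end of the run, writing the results to JSON raised an error; I plan to rewrite that part to save the results as CSV instead.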
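The text does not say what the JSON error was. One common cause when dumping model outputs, which may or may not be the one here, is that the json module cannot serialize NumPy scalar types such as int64. The sketch below shows both options with the standard library: writing question_id/answer rows with csv.DictWriter, as the author plans, and a json.dump default= hook that converts any object exposing .item() (as NumPy scalars do) into a plain Python value. The column names follow the results DataFrame in test(); the helper names are hypothetical.

```python
import csv
import io
import json


def to_builtin(obj):
    """json.dump fallback: convert NumPy-style scalars (anything with .item()) to Python types."""
    if hasattr(obj, "item"):
        return obj.item()
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


def write_results_csv(question_ids, answers, f):
    """One row per question, with the same columns as the results DataFrame in test()."""
    writer = csv.DictWriter(f, fieldnames=["question_id", "answer"])
    writer.writeheader()
    for qid, ans in zip(question_ids, answers):
        writer.writerow({"question_id": qid, "answer": ans})


def write_results_json(question_ids, answers, f):
    """Write [{"question_id": ..., "answer": ...}] records, tolerating NumPy scalars."""
    records = [{"question_id": qid, "answer": ans} for qid, ans in zip(question_ids, answers)]
    json.dump(records, f, default=to_builtin)


if __name__ == "__main__":
    buf = io.StringIO()
    write_results_csv([1, 2], ["2", "frisbee"], buf)
    print(buf.getvalue())
```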

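Now for the test run. I picked three images: an aerial photo of baseball fields, a cartoon image of Pikachu, and a photo of a soccer game:

First create a CSV file with the image paths and the questions you want to ask, and save it as './test/question.csv':

Here are the answers. I expected only the soccer question to be answered correctly:

For the first image I asked how many baseball fields are in the image; the model answered two, which is correct. For the second image I asked for the name of the cute yellow animal; the model answered giraffe, which is wrong.

The last question asked what sport they are playing; the answer was frisbee, also wrong.

IV. Summary

1. VQA is quite hard: both writing the model and loading the data require a lot of care.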
