show-attend-and-tell-tensorflow源码解读:preprocess.py

show,attend, and tell是image caption领域的经典论文,image caption即“看图说话”,它连接了计算机视觉和自然语言处理。本篇博客主要解读show,attend, and tell的预处理代码,计算机视觉需要预处理流程,自然语言处理也需要把自然语言进行词编码与词嵌入。总体而言,该任务的预处理流程是比较复杂的。

在gitub上搜索show-attend-and-tell-tensorflow,排名第一的仓库就是笔者该系列解读的代码。在深入研究解读代码之前,有必要做一点理论介绍和阐述一些模型细节。

  1. <show, attend, and tell>采用的是encoder-decoder架构;encoder是CNN,decoder是RNN,二者不能同时训练,否则the noise in the initial gradients coming from the LSTM into the image model corrupted the CNN and would never recover(show and tell: lessons learned from the 2015 MSCOCO image captioning challenge)。所以遇到的代码中,普遍方法是先使用经过预训练的CNN模型提取feature map/vector(特征图/特征向量,不同模型会有些微差异),再将feature map/vector作为decoder的输入,参与模型的计算。
  2. 对于decoder而言,它负责将feature map/vector 解码为 人能看得懂的自然语言。但计算机无法直接处理自然语言,所以我们需要对数据集中涉及的单词进行one-hot编码及word embedding(词嵌入)

对涉及到的预处理做的简单介绍到此结束,现在我们正式开始解读预处理代码。预处理代码绝不止preprocess.py一个文件,但它作为预处理程序的汇总,我们有必要通过它来更全面深入地掌握预处理流程的全貌。

preprocess.py的导入信息如下:

from scipy import ndimage
"""
scipy.ndimage: Multi-dimentional image processing(多维图像处理包)
用于多为图像处理的各种功能,包含:
1.Filters: 过滤器
2.Fourier filters: 傅里叶过滤器
3.Interpolation: 图像的插值、旋转及仿射变换
4.Measurements: 图像相关信息的测量
5.Morphology: 形态学图像处理
更强大的图像处理库包括:opencv, scikit-image等
"""

from collections import Counter
"""
collections模块包含多种集合类
1.namedtuple: 可以对tuple的某个维度命名,并且还可以根据命名获得该维度的值
2.deque: 适用于队列和栈,可以实现高效的插入和删除。
3.defaultdict: 使用dict时,如果引用的key值不存在,就会抛出keyerror;如果key不存在时,希望能返回一个默认值而非抛出错误,就可以用defaultdict。
4.OrderDict: 使用dict时,key是无序的,在对dict做迭代时,我们无法确定key的顺序;要保持key的顺序,可以用OrderDict
5.Counter: 是一个简单的计数器,统计字符出现的个数,它是dict的一个子类
"""
from core.vggnet import Vgg19
"""
从imagenet-vgg-verydeep-19.mat中获取了预训练参数,并用其构造了vgg19模型的计算流程(模型/类)
"""
from core.utils import *
import tensorflow as tf
import numpy as np
import pandas as pd
import hickle
import os
import json
"""
深度学习会处理大量输入数据,也会输出大量数据,这些数据全放在CPU/GPU显然是不现实的。
通常的做法是先保存到硬盘文件中,待到需要的时候再加载拿来用。
这中间涉及到文件路径的增删改查(与操作系统进行交互),因此需要os模块。
CPU/GPU中数据的保存、硬盘文件的读取,可以通过json, pickle, hickle等模块/库来实现,处理的文件格式分别是json, pkl, hkl。
"""

我们刚刚提到preprocess.py是预处理程序的汇总,所有涉及预处理的组件都在这里应用。因为main()函数就是preprocess.py文件的入口,所以我们先从main()函数开始。

深度学习通常成批处理输入,这里的CNN从图像中提取feature map/vector也不例外,batch_size(100)表示一批图像的数目。vgg_model_path=’./data/imagenet-vgg-verydeep-19.mat’该路径下的文件

  1. 深度学习通常成批处理输入,这里CNN从图像中提取feature map/vector也不例外,batch_size(100)表示一批图像的数目。
  2. vgg_model_path该路径下的文件保存着预训练的vgg模型的参数。core/vggnet.py中的Vgg19类,也就是core.vggnet模块里的Vgg19类,就需要使用这些参数信息来构建vgg19网络。具体如何构建看后续博客解读分析。
  3. max_length(15)与标签语句有关,它规定了一个标签语句最多包含多少单词,超过15就把对应句子删掉。
  4. 在前面我们提到了"需要对数据集中涉及的单词进行one-hot编码及word embedding(词嵌入)"。one-hot编码:把训练集中的单词编成一个词汇表,词汇表中的每个单词都是独一无二的,在词汇表的位置也是独一无二的,假设某个单词在词汇表的位置为i,那么它可以用一个固定长度的向量表示,向量的第i维为1,其它维都为0。但在实际操作中,训练集中的有的单词出现频次过低,那么我们可以将其舍去,不列入词汇表中。当出现频次< word_count_threshold就舍去,这里word_count_threshold=1,即出现的单词全都列入词汇表中。
  5. coco2014数据集中的图片有自己的文件名,同时也有自己唯一的图片id信息,后者在annotation文件中有记录。
def main():
    # batch size for extracting feature vectors from vggnet
    batch_size = 100    # 一次提取100幅图像的feature vectors
    # maximum length of caption (number of word). if caption is longer than max_length, deleted.
    max_length = 15    # 标签语句最长15个单词,超过15个单词的语句删掉
    # if word occurs less than word_count_threshold in training dataset, the word index is special unknown token.
    word_count_threshold = 1 # 如果训练集中某个单词出现次数小于1,那就设为null(一个特殊的token)
    # vgg model path
    vgg_model_path = './data/imagenet-vgg-verydeep-19.mat'
    
    # about 80000 images and 400000 captions for train dataset
    train_dataset = _process_caption_data(caption_file='data/annotations/captions_train2014.json, image_dir='image/train2014_resized', max_length=max_length)
    # 有图像文件夹image_dir,有包含标签语句和图像与标签的连接信息的caption_file,这个函数(后面详细介绍)事实上构建了训练集变量,另外一点:./data/ == data/

    # about 40000 images and 200000 captions
    val_dataset = _process_caption_data(caption_file='data/annotations/captions_val2014.json', image_dir='image/val_resized', max_length=max_length)
    # 这里构建了验证集变量

    # about 4000 images and 20000 captions for val / test dataset
    val_cutoff = int(0.1 * len(val_dataset))
    test_cutoff = int(0.2 * len(val_dataset))
    print('Finished processing caption data')

    save_pickle(train_dataset, 'data/train/train.annotations.pkl')
    save_pickle(val_dataset[:val_cutoff], 'data/val/val.annotations.pkl')
    save_pickle(val_dataset[val_cutoff:test_cutoff].reset_index(drop=True), 'data/test/test.annotations.pkl')
    """
        这里save_pickle()函数与pickle模块有关,pickle模块保存的文件后缀名都是pkl,save_pickle()是对pickle.dump()函数的扩展,它的定义在core.utils模块中(前面导入模块中已经写了)。
        reset_index()方法的全称是pandas.DataFrame.reset_index(),用来防止原索引变成数据列。可见_process_caption_data返回的结果是pd.DataFrame类的实例,但疑点是为什么前两个没用该方法?
        从这儿开始,对上面得到的train, val, test三个文件,逐个执行相关操作。
    """
    
    for split in ['train', 'val', 'test']:
        annotations = load_pickle('./data/%s/%s.annotations.pkl' % (split, split))
        # load_pickle()与save_pickle()情形相似,都位于core.utils模块中(core/utils.py文件中),都是对pickle模块中的函数进行扩展,不同之处在于load_pickle()扩展的是pickle.load()
        
        if split == 'train':
            word_to_idx = _build_vocab(annotations=annotations, threshold=word_count_threshold)
            # 在training阶段,制作词汇表,方便后续的one-hot词编码和词嵌入。
            save_pickle(word_to_idx, '.data/%s/word_to_idx.pkl' % split)    # 把词汇表保存起来
        captions = _build_caption_vector(annotations=annotations, word_to_idx=word_to_idx, max_length=max_length)
        # 制作好词汇表后,对整个句子进行编码
        save_pickle(captions, './data/%s/%s.captions.pkl' % (split, split))    # 对句子编码向量保存起来。
        
        file_names, id_to_idx = _build_file_names(annotations)
        save_pickle(file_names, './data/%s/%s.file.names.pkl' % (split, split))
        image_idxs = _build_image_idxs(annotations, id_to_idx)
        save_pickle(image_idxs, './data/%s/%s.image.idxs.pkl' % (split, split))
        """这四句暂时不清楚具体干了啥,但应该是提取了图片文件名,图片id,标签语句,标签编码之间的关系"""
        # prepare reference captions to compute bleu scores later
        """
        	这部分代码用了前面生成的文件中的参数信息,我们暂时不清楚这些文件的生成细节,所以留在后面详细研究
        """
        image_ids = {}
        feature_to_captions = {}
        i = -1
        for caption, image_id in zip(annotations['caption'], annotations['image_id']):
            if not image_id in image_ids:
                image_ids[image_id] = 0
                i += 1
                feature_to_captions[i] = []
            feature_to_captions[i].append(caption.lower() + ' .')
        save_pickle(feature_to_captions, './data/%s/%s.references.pkl' % (split, split))
        print("finished building %s caption dataset" % split)
	# extract conv5_3 feature vectors
    vggnet = Vgg19(vgg_model_path)
    # 加载预训练的模型参数
    vggnet.build()
    # 加载后构建vgg19模型,得到完整的计算流程
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        for split in ['train', 'val', 'test']:
            anno_path = './data/%s/%s.annotations.pkl' % (split, split)
            save_path = './data/%s/%s.features.hkl' % (split, split)
            annotations = load_pickle(anno_path)
            image_path = list(annotations['file_name'].unique())
            n_example = len(image_path)
            
            all_feats = np.ndarray([n_example, 196, 512], dtype=np.float32)
            
            for start, end in zip(range(0, n_example, batch_size), range(batch_size, n_example+batch_size, batch_size)):
                image_batch_file = image_path[staart:end]
                image_batch = np.array(map(lambda x: ndimage.imread(x, mode='RGB'), image_batch_file)).astype(np.float32)
                feats = sess.run(vggnet.features, feed_dict={vggnet.images:image_batch})
                all_feats[start:end, :] = feats
                print("Processed %d %s features.." % (end, split))
            
            # use hickle to save huge feature vectors
            hickle.dump(all_feats, save_path)
            print("Saved %s.." % (save_path))

接下来是main()函数中第一个调用的函数__process_caption_data(),函数名前面包括下划线通常是指这个函数不希望被本模块以外的函数调用。

def _process_caption_data(caption_file, image_dir, max_length):
	with open(caption_file) as f:
		caption_data = json.load(f)
	# id_to_filename is a dictionary such as {image_id: filename}
	id_to_filename = {image['id']:image['file_name'] for image in caption_data['images']}
	"""
		搜索"coco data format"就能找到官方文档,caption_data由键-值对构成,
		'images'键对应的值是由image结构体组成的列表/数组(总之可以迭代),这里程序把列表/数组中每个元素也命名为image了。
		image结构体也是由键-值对构成,键包括:
		"id", "width", "height", "file_name", "license", "flickr_url", "coco_url", "data_captured"
		这里使用了"id", "filename"两个键。
	"""
	# data is a list of dictionary which contains 'captions', 'filename' and 'image_id' as key.
	"""
		上面这个注释写错了,是'caption'不是'captions',虽只有一字之差,但'captions'是读取了captions_train(val)2014.json后能直接访问的键,要想访问'caption'键,需要先读取captions_train(val)2014.json文件后访问'annotations'键,然后对'annotations'键的值进行遍历(值是一个列表),列表中的元素才能访问'caption'键
	"""
	data = []
	for annotation in caption_data['annotations']:
		image_id = annotation['image_id']
		annotation['file_name'] = os.path.join(image_dir, id_to_filename[image_id])
		data += [annotation]
	"""
		与上面的'images'键一样,'annotations'也是caption_data可以直接访问的键。
		'annotations'键对应的值是由annotation结构体组成的列表/数组(总之可以迭代),这里程序把列表/数组中每个元素也命名为annotation了。
		annotation结构体也是由键-值对构成,键包括:"id", "image_id", "caption"
		
		image结构体中的file_name<--->image结构体中id==annotation结构体中image_id
		通过上述连接关系,在annotation结构体中添加键"file_name",并把每个annotation结构实例添加进data列表里。
	"""
	# convert to pandas dataframe (for later visualization or debugging)
	caption_data = pd.DataFrame.from_dict(data)
	# 这时候的caption_data已经不是原先的open()函数后的那个了。
	del caption_data['id']
	# caption_data删除对'id'列的引用,此时只剩'caption', 'file_name', 'image_id'列了。
	caption_data.sort_values(by='image_id', inplace=True)
	"""
		pandas.DataFrame.sort_values(by='xxx', inplace=True)
		参数by对应的列的元素类型必须为数字类型,这里'image_id'列元素类型为int
		另外inplace=True时,为就地排序,无需再赋值就能保存排序后的结果
	"""
	caption_data = caption_data.reset_index(drop=True)
	# 我的pandas版本是0.24.2,测试了pd.DataFrame.sort_value(),发现舍弃reset.index()也没问题

	del_idx = []
	for i, caption in enumerate(caption_data['caption']):
		# 上面提到caption_data现在包含的列有:'captions', 'file_name', 'image_id'。
		caption = caption.replace('.', '').replace(',','').replace("'","").replace('"','')
		caption = caption.replace('&', 'and').replace('(','').replace(")","").replace('-',' ')
		# 最后一个replace函数的第二个参数为什么是空格字符,和其他的不一样?
		caption = " ".join(caption.split()) # replace multiple spaces
		
		caption_data.set_value(i, 'caption', caption.lower())
		# 更改caption_data的第i行,'caption'列的值,将其由caption变为caption.lower()
		if len(caption.split(" ")) > max_length:
			# 虽然切分了,但并没有赋值,所以caption指向的数据对象值没有发生改变
			del_idx.append(i)	# 这里也可以看出idx是index的意思,表示DataFrame结构的行索引
	# delete captions if size is larger than max_length
	print "The number of captions before deletion: %d" % len(caption_data)
	caption_data = caption_data.drop(caption_data.index[del_idx])
	# 或许是版本原因,或许就是有问题,搜索引擎搜不出pd.DataFrame.index()函数,然后在代码中也找不到定义?
	# 我认为这里应该写作:caption_data = caption_data.drop(index=del_idx)
	caption_data = caption_data.reset_index(drop=True)
	print "The number of captions after deletion: %d" % len(caption_data)
	return caption_data

简而言之,_process_caption_data处理了训练集和测试集的json文件,通过访问键’images’和键’captions’,获取了图片的id信息,图片文件名和对应的描述,把它处理成pd.DataFrame类实例,删除其中过长的文本描述后返回。

现在我们来解读下一个函数: _build_vocab()

def _build_vocab(annotations, threshold=1):
	# 在main()函数中,_build_vocab()只调用了一次
	# 包含训练集全部信息的annotations/captions_train2014.json文件,经过_process_caption_data()函数处理,
	# 得到的结果是:包含图片id信息'image_id',图片文件名'file_name', 图片描述'caption'三个数据列的pandas.DataFrame类实例。而后把它存成pkl文件:train.annotations.pkl。
	# annotations读取了train.annotations.pkl中包含的信息,所以它包含三个键'image_id', 'file_name', 和'caption' 
	counter = Counter()
	max_len = 0
	for i, caption in enumerate(annotations['caption']):
		words = caption.split(' ')
		for w in words:
			counter[w] += 1
		if len(caption.split(" ")) > max_len:
			max_len = len(caption.split(" "))
	
	vocab = [word for word in counter if counter[word] >= threshold]
	# 遍历Counter类实例counter的键,并添加判断条件,符合条件被添加进列表里。dict类实例也有类似功能
	print('Filtered %d words to %d words with word count threshold %d.' % (len(counter), len(vocab), threshold)

	word_to_idx = {u'<NULL>': 0, U'<START>': 1, u'<END>': 2}
	# 这里再度说明idx指的是index。
	idx = 3
	for word in vocab:
		word_to_idx[word] = idx
		idx += 1
	print "Max length of caption: ", max_len
	return word_to_idx

现在我们解读 _build_caption_vector()函数。

def _build_caption_vector(annotations, word_to_idx, max_length=15):
	# 这里的annotations与上面的_build_vocab()函数中的annotations一样,均是处理训练集等到的
	# 不同之处在于这里的annotations,还将处理验证集和测试集
	# 三个键:'image_id', 'file_name', 'caption'
	n_example = len(annotations)
	captions = np.ndarray((n_examples, max_length+2)).astype(np.int32)
	for i, caption in enumerate(annotations['caption']):
		words = caption.split(" ") # caption contrains only lower-case words
		# 将一张图片的一句描述给切分成单词列表
		cap_vec = []
		cap_vec.append(word_to_idx['<START>'])
		for word in words:
			if word in word_to_idx:
				cap_vec.append(word_to_idx[word])
		cap_vec.append(word_to_idx['<NULL>'])
		# 将单词列表转换成索引向量(在前后分别添上<START>和<END>)
		# pad short caption with the special null token '<NULL>' to make it fixed-size vector
		if len(cap_vec) < (max_length + 2):
			for j in range(max_length + 2 - len(cap_vec)):
				cap_vec.append(word_to_idx['<NULL>'])
		
		captions[i, :] = np.asarray(cap_vec)
		# numpy.asarray()可以将列表、元组、元组列表、列表元组转换成numpy数组
	print "Finished building caption vectors"
	return captions		

现在解读_build_file_name()函数

def _build_file_names(annotations):
	# 这里的annotations等同于_build_caption_vector里的annotations
	image_file_names = []
	id_to_idx = {}
	idx = 0
	image_ids = annotations['image_id']
	file_names = annotations['file_name']
	# 由于一张图片有多个描述句子,意味着annotations['image_id']得到的结果里有重复的id信息, annotations['file_name']也有重复的file_name信息,也就意味着image_ids和file_names各自均有重复元素。
	for image_id, file_name in zip(image_ids, file_names):
		if not image_id in id_to_idx:
			id_to_idx[image_id] = idx
			image_file_names.append(file_name)
			idx += 1
	# 上面的for循环,就是起到剔除重复元素的作用,但这时又引入了一个idx,作为id_to_idx中键image_id的值
	file_names = np.asarray(image_file_names)
	return file_names, id_to_idx

最后是_build_image_idxs()函数了

def _build_image_idxs(annotations, id_to_idx):
	# 这里的annotations等同上面函数的annotations。
	image_idxs = np.ndarray(len(annotations), dtype=np.int32)
	image_ids = annotations['image_id']
	# image_idxs和image_ids尺寸大小相同
	for i, image_id in enumerate(image_ids):
		image_idxs[i] = id_to_idx[image_id]
		# 暂时没搞懂多了个复杂的image_idxs有什么好处
	return image_idxs
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值