1204204 最近一直在接触图像生成文本这个东西,看了一些相关的论文,也运行过github上大神们的代码,特别是densecap(密集描述网络),觉得很有意义。到了自己下手准备搭建并一定程度复现框架时,感觉还是有一些棘手。但是总得踏出第一步,理解一个相对简单的框架,并在后续在其基础上进行改变并完善,这样的路线应该会好一点。
环境:python3.6 tensorflow-1.13.1 -gpu
数据集来源 flickr30k-images
(下载链接:https://pan.baidu.com/s/1r0RVUwctJsI0iNuVXHQ6kA密码:hrf3
flickr30k-images.tar是图像,flickr30k.tar.gz是标注)数量为3w张图片
解压完的图片数据集如图所示。
其图片的描述文件在一个token文件中,如下面的格式,每张图有多个标签
1001465944.jpg#2 A young woman walks past two young people dressed in hip black outfits .
1001465944.jpg#3 A woman with a large purse is walking by a gate .
1001465944.jpg#4 Several people standing outside a building .
1001545525.jpg#0 Two men in Germany jumping over a rail at the same time without shirts .
1001545525.jpg#1 Two youths are jumping over a roadside railing , at night .
1001545525.jpg#2 Boys dancing on poles in the middle of the night .
1001545525.jpg#3 Two men with no shirts jumping over a rail .
1001545525.jpg#4 two guys jumping over a gate together
1001573224.jpg#0 Five ballet dancers caught mid jump in a dancing studio with sunlight coming through a window .
1001573224.jpg#1 Ballet dancers in a studio practice jumping with wonderful form .
1001573224.jpg#2 Five girls are leaping simultaneously in a dance practice room .
1001573224.jpg#3 Five girls dancing and bending feet in ballet class .
对于描述,首先会统计一下描述的词表。统计一下长度分布,词频。
#-*- coding: utf-8 -*-
import os
import sys
import pprint
input_descripyion_file='results_20130124.token'
output_vocab_file='vocab.txt'
def count_vocab(input_descripyion_file):
with open(input_descripyion_file,'rb') as f: #读取所有的内容
lines=f.readlines()
max_length_of_sentences=0 #记录最长句子长度
length_dict={} #统计句子长度分布
vocab_dict={} #生成词表
for line in lines:
description= line.decode().strip('n').strip('t') #切分描述
words= description.strip(' ').split() #用空格切分单词
max_length_of_sentences=max(max_length_of_sentences, len(words)) #更新最长的句子长度
length_dict.setdefault(len(words),0)#分布
length_dict[len(words)]+=1
for word in words:#统计词表
vocab_dict.setdefault(word,0)
vocab_dict[word]+=1 #更新词表
print(max_length_of_sentences)
pprint.pprint(length_dict)
return vocab_dict
vocab_dict=count_vocab(input_descripyion_file)
结果为:83
{3: 14,
4: 52,
5: 297,
6: 1109,
7: 3593,
8: 7895,
9: 11070,
10: 13165,
11: 14821,
12: 15427,
13: 14481,
14: 12919,
15: 11394,
16: 9952,.....
最大的句子长度为83,和其他的句子长度的分布。句子长度在30-40之后的句子数量很少,我们可以在训练时忽略这些过长的句子来提高我们训练的效率。
sorted_vocab_dict=sorted(vocab_dict.items(),
key=lambda d:d[1],reverse=True)#对词表进行排序
with open(output_vocab_file,'w',encoding="utf-8") as f:
f.write ('<UNK>t10000000n')
for item in sorted_vocab_dict:
f.write('%st%dn'%item)
并对词表中的词语进行排序,并定义一个unk特殊字符来指代测试时没有遇到过的词,并输出一个vocab的文档文件。
有大约1w个字的出现次数是超过3次的。
接下来我们需要对图片特征进行处理,选择的是inception_v3这个特征提取模型。其下载链接在 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'中。其解压完的结构为
我们需要的文件为文件格式为classify_image_graph_def.pb模型文件。
import os
import sys
import tensorflow as tf
from tensorflow import gfile #文件处理库
from tensorflow import logging
import pprint
#import cPickle
import _pickle as cPickle
import numpy as np
model_file = "inception-2015-12-05/classify_image_graph_def.pb" #模型文件
input_description_file = "results_20130124.token" #图像描述文件 包含图片名和描述
input_img_dir = "flickr30k-images/" #图片文件
output_folder = "feature_extraction_inception_v3" #输出文件
batch_size = 100 # 每次传入图片 参数 将3w张图片分批处理
if not gfile.Exists(output_folder):#判断文件是否存在
gfile.MakeDirs(output_folder)
def parse_token_file(token_file):
#图片描述文件处理
img_name_to_tokens = {} #图片名和描述的对应
with gfile.GFile(token_file, 'r') as f:
lines = f.readlines()#所有行读取
for line in lines:
img_id, description = line.strip('rn').split('t') #切分名和描述
img_name,