论文为 Deep Visual-Semantic Alignments for Generating Image Descriptions,是 Andrej Karpathy 与李飞飞发表在 CVPR 2015 的论文。
官方实现代码为 neuraltalk2,基于 Torch(Lua),在 GPU 上运行。
将代码 clone 下来,并下载预训练 model 及测试图,按照 README 中的步骤即可运行 eval。
复现准备使用 Show and Tell 的基于 TensorFlow 的模型 im2txt。
目录
/im2txt/ops/image_processing.py
/im2txt/ops/image_embedding.py
/im2txt/ops/image_embedding_test.py
/im2txt/data/build_mscoco_data.py
/im2txt/opt
1.图像预处理脚本
/im2txt/ops/image_processing.py
1.导入包
from __future__ import absolute_import //在 python2.x中用python3.x的特性,表示想对导入中,默认导入系统的包
from __future__ import division //执行的不是截断除法,而是精确除法
from __future__ import print_function //可以使用python3.x的print()
import tensorflow as tf
2.图像变形(数据增强)
疑问:preprocessing threads 是进程数?
答:是线程数,而不是进程数。distort_image 用 thread_id % 2 选择颜色扰动的顺序,所以预处理线程数应为 2 的倍数,这样两种顺序才会被均匀使用。
def distort_image(image, thread_id):
    """Apply a random horizontal flip and random colour jitter to one image.

    Args:
      image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
      thread_id: Preprocessing thread id. Its parity picks one of two orders
        in which the colour distortions are applied, so the number of
        preprocessing threads should be a multiple of 2.

    Returns:
      A float32 Tensor of shape [height, width, 3] clipped to [0, 1].
    """
    # Random left/right mirror.
    with tf.name_scope("flip_horizontal", values=[image]):
        image = tf.image.random_flip_left_right(image)

    def brightness(img):
        # Additive brightness shift drawn from [-32/255, 32/255].
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def saturation(img):
        return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def hue(img):
        return tf.image.random_hue(img, max_delta=0.032)

    def contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    # Even thread ids use the first ordering, odd ones the second.
    color_steps = {
        0: (brightness, saturation, hue, contrast),
        1: (brightness, contrast, saturation, hue),
    }

    with tf.name_scope("distort_color", values=[image]):
        for step in color_steps.get(thread_id % 2, ()):
            image = step(image)
        # The random_* ops do not necessarily clamp.
        image = tf.clip_by_value(image, 0.0, 1.0)

    return image
3.图像处理
问题:为什么是 [-1, 1] 之间?
答:函数最后用 (image - 0.5) * 2 把 [0, 1] 映射到 [-1, 1],这是 Inception v3 预训练模型所使用的输入范围。
def process_image(encoded_image,
                  is_training,
                  height,
                  width,
                  resize_height=346,
                  resize_width=346,
                  thread_id=0,
                  image_format="jpeg"):
    """Turn an encoded image string into a model-ready float32 image Tensor.

    Pipeline: decode -> convert to float32 in [0, 1) -> optional bilinear
    resize -> crop to [height, width] -> random distortions (training only)
    -> rescale to [-1, 1].

    Args:
      encoded_image: String Tensor containing the encoded image.
      is_training: If true, use a random crop and random distortions;
        otherwise use a central crop (or pad).
      height: Output image height.
      width: Output image width.
      resize_height: Height to resize to before cropping; 0 disables resizing.
      resize_width: Width to resize to before cropping; 0 disables resizing.
      thread_id: Preprocessing thread id. Image summaries are written only for
        thread 0; the id is also passed to distort_image.
      image_format: "jpeg" or "png".

    Returns:
      A float32 Tensor of shape [height, width, 3] with values in [-1, 1].

    Raises:
      ValueError: If image_format is neither "jpeg" nor "png".
    """
    def image_summary(name, img):
        # Only thread 0 writes summaries.
        if thread_id:
            return
        # tf.summary.image expects [batch, height, width, channels].
        tf.summary.image(name, tf.expand_dims(img, 0))

    with tf.name_scope("decode", values=[encoded_image]):
        if image_format == "jpeg":
            decode_fn = tf.image.decode_jpeg
        elif image_format == "png":
            decode_fn = tf.image.decode_png
        else:
            raise ValueError("Invalid image format: %s" % image_format)
        image = decode_fn(encoded_image, channels=3)

    # uint8 [0, 255] -> float32 [0, 1).
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image_summary("original_image", image)

    # Resizing is either enabled for both dimensions or for neither.
    assert (resize_height > 0) == (resize_width > 0)
    if resize_height:
        image = tf.image.resize_images(
            image,
            size=[resize_height, resize_width],
            method=tf.image.ResizeMethod.BILINEAR)

    if is_training:
        image = tf.random_crop(image, [height, width, 3])
    else:
        # Central crop; assumes the resized image is at least height x width.
        image = tf.image.resize_image_with_crop_or_pad(image, height, width)
    image_summary("resized_image", image)

    if is_training:
        image = distort_image(image, thread_id)
    image_summary("final_image", image)

    # Map [0, 1] to [-1, 1].
    image = tf.multiply(tf.subtract(image, 0.5), 2.0)
    return image
2.输入脚本
/im2txt/ops/inputs.py
1.解析 SequenceExample(得到编码后的图像和 caption)
def parse_sequence_example(serialized, image_feature, caption_feature):
    """Parses a tensorflow.SequenceExample into an image and caption.

    Args:
      serialized: A scalar string Tensor; a single serialized SequenceExample.
      image_feature: Name of the SequenceExample context feature containing
        the encoded image data.
      caption_feature: Name of the SequenceExample feature list containing
        the integer caption word ids.

    Returns:
      encoded_image: A scalar string Tensor containing a JPEG encoded image.
      caption: A 1-D int64 Tensor of word ids with dynamically specified length.
    """
    context, sequence = tf.parse_single_sequence_example(
        serialized,
        # Non-sequential part: exactly one scalar string per example.
        context_features={
            image_feature: tf.FixedLenFeature([], dtype=tf.string),
        },
        # Sequential part: one scalar int64 per time step, variable length.
        sequence_features={
            caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
        })

    encoded_image = context[image_feature]
    caption = sequence[caption_feature]
    return encoded_image, caption
2.利用tensorflow中队列和多线程加速
- 队列类型
FIFOQueue()、RandomShuffleQueue()
queue1 = tf.RandomShuffleQueue(...)
queue2 = tf.FIFOQueue(...)
- 出入队列操作
enqueue()、enqueue_many()
dequeue()、dequeue_many()
enqueue_op = queue.enqueue(example)
inputs = queue.dequeue_many(batch_size)
def prefetch_input_data(reader,
                        file_pattern,
                        is_training,
                        batch_size,
                        values_per_shard,
                        input_queue_capacity_factor=16,
                        num_reader_threads=1,
                        shard_queue_name="filename_queue",
                        value_queue_name="input_queue"):
    """Prefetches serialized records from disk into an input queue.

    In training, records are buffered in a RandomShuffleQueue. It always keeps
    at least values_per_shard * input_queue_capacity_factor elements, so
    examples from different shards get mixed; a larger factor mixes better but
    uses more memory. In evaluation, records flow through a FIFOQueue in file
    order.

    Args:
      reader: Reader instance whose read(queue) returns (key, value) Tensors,
        e.g. a tf.TFRecordReader.
      file_pattern: Comma-separated list of file glob patterns.
      is_training: Whether to shuffle the file list and the records.
      batch_size: Model batch size, used to size the queue.
      values_per_shard: Approximate number of values per shard.
      input_queue_capacity_factor: Minimum number of values kept in the
        training queue, in multiples of values_per_shard.
      num_reader_threads: Number of reader threads filling the queue.
      shard_queue_name: Name for the filename queue.
      value_queue_name: Name suffix for the values queue.

    Returns:
      A queue of serialized string records.

    Raises:
      ValueError: If no files match file_pattern.
    """
    data_files = []
    for pattern in file_pattern.split(","):
        data_files.extend(tf.gfile.Glob(pattern))
    if not data_files:
        # tf.logging.fatal only logs at FATAL level; it does not stop
        # execution. Fail here with a clear message instead of letting
        # string_input_producer reject the empty list later.
        tf.logging.fatal("Found no input files matching %s", file_pattern)
        raise ValueError("Found no input files matching %s" % file_pattern)
    tf.logging.info("Prefetching values from %d files matching %s",
                    len(data_files), file_pattern)

    if is_training:
        filename_queue = tf.train.string_input_producer(
            data_files, shuffle=True, capacity=16, name=shard_queue_name)
        min_queue_examples = values_per_shard * input_queue_capacity_factor
        capacity = min_queue_examples + 100 * batch_size
        # Dequeue is allowed only while at least min_queue_examples remain
        # buffered, so every batch is drawn from a well-mixed pool.
        values_queue = tf.RandomShuffleQueue(
            capacity=capacity,
            min_after_dequeue=min_queue_examples,
            dtypes=[tf.string],
            name="random_" + value_queue_name)
    else:
        # No shuffling in evaluation; presumably a shard plus a few batches of
        # headroom is enough, hence 3 * batch_size instead of 100 * batch_size.
        filename_queue = tf.train.string_input_producer(
            data_files, shuffle=False, capacity=1, name=shard_queue_name)
        capacity = values_per_shard + 3 * batch_size
        values_queue = tf.FIFOQueue(
            capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)

    # One read+enqueue op per reader thread; the QueueRunner runs each
    # enqueue op in its own thread once queue runners are started.
    enqueue_ops = []
    for _ in range(num_reader_threads):
        _, value = reader.read(filename_queue)
        enqueue_ops.append(values_queue.enqueue([value]))
    tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
        values_queue, enqueue_ops))

    # Fill level of the queue, useful for spotting input bottlenecks.
    tf.summary.scalar(
        "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
        tf.cast(values_queue.size(), tf.float32) * (1. / capacity))

    return values_queue
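3.动态填充并组成 batch(batch_with_dynamic_pad)
作用:把每条 caption 拆成两部分。输入序列去掉最后一个词;目标序列去掉第一个词,即输入序列右移一位。然后按 batch 中最长的序列补齐(dynamic pad),并生成 mask 区分真实词与填充位。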
def batch_with_dynamic_pad(images_and_captions,
                           batch_size,
                           queue_capacity,
                           add_summaries=True):
    """Batches images together with caption input/target sequences and a mask.

    Each caption of length L is split into:
      * an input sequence caption[0:L-1] (all words but the last),
      * a target sequence caption[1:L] (all words but the first, i.e. the
        input shifted by one word),
      * an all-ones int32 indicator of length L-1.
    tf.train.batch_join with dynamic_pad=True pads every sequence in a batch
    up to the longest one in that batch. The padded indicator therefore
    becomes a 0/1 mask that separates real words from padding.

    Example ('-' marks padding):
      captions:    [[1 2 3 4 5], [1 2 3 4 -], [1 2 3 - -]]
      input_seqs:  [[1 2 3 4],   [1 2 3 -],   [1 2 - -]]
      target_seqs: [[2 3 4 5],   [2 3 4 -],   [2 3 - -]]
      mask:        [[1 1 1 1],   [1 1 1 0],   [1 1 0 0]]

    Args:
      images_and_captions: A list of [image, caption] pairs. image is a Tensor
        of shape [height, width, channels]; caption is a 1-D Tensor of any
        length. Each pair is enqueued by a separate thread.
      batch_size: Batch size.
      queue_capacity: Queue capacity.
      add_summaries: If true, add min/max/mean caption-length summaries.

    Returns:
      images: A Tensor of shape [batch_size, height, width, channels].
      input_seqs: A Tensor of shape [batch_size, padded_length] with the
        caption's dtype.
      target_seqs: A Tensor of shape [batch_size, padded_length] with the
        caption's dtype.
      mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
    """
    def _split_caption(image, caption):
        # Number of (input word, target word) pairs, as a 1-element vector.
        input_length = tf.expand_dims(tf.subtract(tf.shape(caption)[0], 1), 0)
        return [
            image,
            tf.slice(caption, [0], input_length),
            tf.slice(caption, [1], input_length),
            tf.ones(input_length, dtype=tf.int32),
        ]

    enqueue_list = [_split_caption(image, caption)
                    for image, caption in images_and_captions]

    images, input_seqs, target_seqs, mask = tf.train.batch_join(
        enqueue_list,
        batch_size=batch_size,
        capacity=queue_capacity,
        dynamic_pad=True,
        name="batch_and_pad")

    if add_summaries:
        # Summing the mask along axis 1 counts the real input words of each
        # caption (one row per caption); +1 recovers the full caption length.
        lengths = tf.add(tf.reduce_sum(mask, 1), 1)
        for tag, reduce_fn in (("caption_length/batch_min", tf.reduce_min),
                               ("caption_length/batch_max", tf.reduce_max),
                               ("caption_length/batch_mean", tf.reduce_mean)):
            tf.summary.scalar(tag, reduce_fn(lengths))

    return images, input_seqs, target_seqs, mask
3. image_embedding.py
/im2txt/ops/image_embedding.py
这个文件中调用了inception_v3
要使 arg_scope 成功跑起来需要两个步骤:
- 用 @add_arg_scope 修饰目标函数;
- 用 with arg_scope(...) 设置默认参数。
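net 的 shape?
答:inception_v3_base 输出的 net 形状为 [batch, height, width, channels];对 299×299 的输入约为 [batch, 8, 8, 2048]。经过全图平均池化和 flatten 后得到 [batch, 2048],与测试中对 embeddings 形状的断言一致。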
tf.truncated_normal_initializer 从截断的正态分布中输出随机值。
生成的值服从具有指定均值和标准差的正态分布;如果生成的值与均值的偏差超过 2 个标准差,则丢弃并重新采样。
为了使代码更简洁,采用arg_scope()定义,在这里定义的层参数一致,不同的可以下面分别定义。
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
# Short alias for TF-Slim; used below for slim.arg_scope, slim.conv2d, etc.
slim = tf.contrib.slim
def inception_v3(images,
                 trainable=True,
                 is_training=True,
                 weight_decay=0.00004,
                 stddev=0.1,
                 dropout_keep_prob=0.8,
                 use_batch_norm=True,
                 batch_norm_params=None,
                 add_summaries=True,
                 scope="InceptionV3"):
    """Builds an Inception V3 subgraph for image embeddings.

    Args:
      images: A float32 Tensor of shape [batch, height, width, channels].
      trainable: Whether the inception submodel should be trainable or not.
      is_training: Boolean indicating training mode or not.
      weight_decay: Coefficient for L2 weight regularization (only applied
        when trainable).
      stddev: The standard deviation of the truncated normal weight initializer.
      dropout_keep_prob: Dropout keep probability.
      use_batch_norm: Whether to apply batch normalization after each
        convolution. When False, convolutions use no normalizer at all.
      batch_norm_params: Parameters for batch normalization. See
        tf.contrib.layers.batch_norm for details. Built-in defaults are used
        when this is None or empty.
      add_summaries: Whether to add activation summaries for every end point.
      scope: Optional variable scope.

    Returns:
      A 2-D Tensor of shape [batch, num_features]: the final Inception V3
      feature map, average-pooled over its spatial dimensions, passed through
      dropout and flattened.
    """
    # Only consider the inception model to be in training mode if it's trainable.
    is_inception_model_training = trainable and is_training

    if use_batch_norm:
        # Default parameters for batch normalization.
        if not batch_norm_params:
            batch_norm_params = {
                "is_training": is_inception_model_training,
                "trainable": trainable,
                # Decay for the moving averages.
                "decay": 0.9997,
                # Epsilon to prevent 0s in variance.
                "epsilon": 0.001,
                # Put the moving mean and moving variance variables into a
                # "moving_vars" collection; beta/gamma use default collections.
                "variables_collections": {
                    "beta": None,
                    "gamma": None,
                    "moving_mean": ["moving_vars"],
                    "moving_variance": ["moving_vars"],
                }
            }
        normalizer_fn = slim.batch_norm
    else:
        # Disable normalization entirely. Leaving slim.batch_norm as the
        # normalizer_fn would still apply batch norm with default parameters.
        batch_norm_params = None
        normalizer_fn = None

    if trainable:
        weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    else:
        weights_regularizer = None

    with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
        with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                weights_regularizer=weights_regularizer,
                trainable=trainable):
            with slim.arg_scope(
                    [slim.conv2d],
                    weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=normalizer_fn,
                    normalizer_params=batch_norm_params):
                # net: last feature map before the classifier, shaped
                # [batch, height, width, channels]; end_points: dictionary of
                # activations from the inception_v3 layers.
                net, end_points = inception_v3_base(images, scope=scope)
                with tf.variable_scope("logits"):
                    shape = net.get_shape()
                    # Pool over the full spatial extent (height, width).
                    net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
                    net = slim.dropout(
                        net,
                        keep_prob=dropout_keep_prob,
                        is_training=is_inception_model_training,
                        scope="dropout")
                    net = slim.flatten(net, scope="flatten")

    # Add summaries.
    if add_summaries:
        for v in end_points.values():
            tf.contrib.layers.summaries.summarize_activation(v)

    return net
4. image_embedding_test.py
/im2txt/ops/image_embedding_test.py
对3. image_embedding.py的测试
class InceptionV3Test(tf.test.TestCase):
    """Graph-construction tests for image_embedding.inception_v3."""

    def setUp(self):
        super(InceptionV3Test, self).setUp()
        self._batch_size = 4
        # A batch of 299x299 RGB images, fed at run time.
        self._images = tf.placeholder(
            tf.float32, [self._batch_size, 299, 299, 3])

    def _countInceptionParameters(self):
        """Return {"InceptionV3/<top-level block>": number of parameters}."""
        totals = {}
        for variable in tf.global_variables():
            parts = variable.op.name.split("/")
            if parts[0] != "InceptionV3":
                continue
            block = "InceptionV3/" + parts[1]
            size = variable.get_shape().num_elements()
            # num_elements() is None for partially known shapes and 0 for
            # empty ones; both are unexpected here.
            assert size
            totals[block] = totals.get(block, 0) + size
        return totals

    def _verifyParameterCounts(self):
        """Check per-block parameter counts against the known architecture."""
        expected = {
            "InceptionV3/Conv2d_1a_3x3": 960,
            "InceptionV3/Conv2d_2a_3x3": 9312,
            "InceptionV3/Conv2d_2b_3x3": 18624,
            "InceptionV3/Conv2d_3b_1x1": 5360,
            "InceptionV3/Conv2d_4a_3x3": 138816,
            "InceptionV3/Mixed_5b": 256368,
            "InceptionV3/Mixed_5c": 277968,
            "InceptionV3/Mixed_5d": 285648,
            "InceptionV3/Mixed_6a": 1153920,
            "InceptionV3/Mixed_6b": 1298944,
            "InceptionV3/Mixed_6c": 1692736,
            "InceptionV3/Mixed_6d": 1692736,
            "InceptionV3/Mixed_6e": 2143872,
            "InceptionV3/Mixed_7a": 1699584,
            "InceptionV3/Mixed_7b": 5047872,
            "InceptionV3/Mixed_7c": 6080064,
        }
        self.assertDictEqual(expected, self._countInceptionParameters())

    def _assertCollectionSize(self, expected_size, collection):
        """Fail unless graph collection `collection` holds expected_size items."""
        actual_size = len(tf.get_collection(collection))
        if actual_size == expected_size:
            return
        self.fail("Found %d items in collection %s (expected %d)." %
                  (actual_size, collection, expected_size))
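tf.GraphKeys 中的几个集合:
GLOBAL_VARIABLES
- 使用 tf.get_variable() 创建变量时,默认放入这个集合
TRAINABLE_VARIABLES
- 将由优化器训练的变量子集
UPDATE_OPS
- 需要在训练步中额外运行的更新操作(如 batch norm 滑动均值/方差的更新)
- 只有 trainable 和 is_training 都为真时才会创建;所以测试中只有第一种情况是 188,其余为 0
REGULARIZATION_LOSSES
- 在图构造期间收集的正则化损失
LOSSES
- 通过 tf.losses 添加的损失
SUMMARIES
- 在图中创建的 summary 张量
这段可以再斟酌一下!!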
def testTrainableTrueIsTrainingTrue(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=True, is_training=True)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableTrueIsTrainingFalse(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=True, is_training=False)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableFalseIsTrainingTrue(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=False, is_training=True)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES) #tf.GraphKeys.GLOBAL_VARIABLES
self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableFalseIsTrainingFalse(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=False, is_training=False)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.GLOBAL_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
/im2txt/data/
5.build_mscoco_data.py
/im2txt/data/build_mscoco_data.py
TFRecord:是谷歌推荐的一种二进制文件格式
The SequenceExample proto contains the following fields:
context:
image/image_id: integer MSCOCO image identifier
image/data: string containing JPEG encoded image in RGB colorspace
feature_lists:
image/caption: list of strings containing the (tokenized) caption words
image/caption_ids: list of integer ids corresponding to the caption words
NLTK: natural language toolkit
作用:
- nltk.text.Text() 类用于对文本进行初级的统计与分析,它接受一个词的列表作为参数。
- nltk.text.TextCollection() 类是 Text 的集合,提供跨多个文本的统计方法(如 tf、idf、tf_idf)。
1.
from collections import Counter #计数器,统计个数
from collections import namedtuple #nametuple用来创建一个自定义的图片了对象,并且规定了元素的个数 e.g nametuple('Point',['x','y']), p = Point(1,2)
from datetime import datetime
import json
import os.path
import random
import sys
import threading #threading模块用来创建线程,直接从threading.Thread继承
import nltk.tokenize # 将大部件切成小部件
import numpy as np
from six.moves import xrange # xrange生成生成器,list(xrange(3,5))可以用来生产list
import tensorflow as tf
2.
# Input locations: MSCOCO images and caption annotation JSON files.
tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
                       "Training image directory.")
tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
                       "Validation image directory.")
tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
                       "Training captions JSON file.")
tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_val2014.json",
                       "Validation captions JSON file.")

# Output directory and the number of TFRecord files (shards) per split.
tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")
tf.flags.DEFINE_integer("train_shards", 256,
                        "Number of shards in training TFRecord files.")
tf.flags.DEFINE_integer("val_shards", 4,
                        "Number of shards in validation TFRecord files.")
tf.flags.DEFINE_integer("test_shards", 8,
                        "Number of shards in testing TFRecord files.")

# Special tokens added around / substituted into tokenized captions.
tf.flags.DEFINE_string("start_word", "<S>",
                       "Special word added to the beginning of each sentence.")
tf.flags.DEFINE_string("end_word", "</S>",
                       "Special word added to the end of each sentence.")
tf.flags.DEFINE_string("unknown_word", "<UNK>",
                       "Special word meaning 'unknown'.")

# Vocabulary construction.
tf.flags.DEFINE_integer("min_word_count", 4,
                        "The minimum number of occurrences of each word in the "
                        "training set for inclusion in the vocabulary.")
tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
                       "Output vocabulary file of word counts.")

# Number of worker threads used when writing each split.
tf.flags.DEFINE_integer("num_threads", 8,
                        "Number of threads to preprocess the images.")

FLAGS = tf.flags.FLAGS

# One image: its MSCOCO id, full file path, and list of captions. After
# _load_and_process_metadata each caption is a list of token strings.
ImageMetadata = namedtuple("ImageMetadata",
                           ["image_id", "filename", "captions"])
3.
class Vocabulary(object):
    """Maps word strings to integer ids, with a fallback id for unknown words."""

    def __init__(self, vocab, unk_id):
        """Store the word -> id lookup table.

        Args:
          vocab: A dictionary mapping word to word_id.
          unk_id: Id returned for any word missing from `vocab`.
        """
        self._vocab = vocab
        self._unk_id = unk_id

    def word_to_id(self, word):
        """Return the integer id of `word`, or the unknown-word id if absent."""
        if word not in self._vocab:
            return self._unk_id
        return self._vocab[word]
4.
class ImageDecoder(object):
    """Decodes JPEG bytes to numpy arrays through one long-lived tf.Session."""

    def __init__(self):
        # Build the decode graph once and reuse a single Session for all calls.
        self._sess = tf.Session()
        self._encoded_jpeg = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)

    def decode_jpeg(self, encoded_jpeg):
        """Decode `encoded_jpeg` and return the image as a (h, w, 3) array.

        Raises:
          AssertionError: If the decoded result is not a 3-channel image.
        """
        # Feed the raw bytes into the placeholder and run the decode op.
        decoded = self._sess.run(self._decode_jpeg,
                                 feed_dict={self._encoded_jpeg: encoded_jpeg})
        # Kept as asserts: _to_sequence_example catches AssertionError to skip
        # bad files.
        assert len(decoded.shape) == 3
        assert decoded.shape[2] == 3
        return decoded
参考:https://zhuanlan.zhihu.com/p/40588218
要将数据写入 .tfrecords 文件,需要先将每一个样本封装为 tf.train.Example 格式,再将 Example 逐个写入文件。Example 中数据的基础类型是 tf.train.Feature,其取值类型有 BytesList、FloatList、Int64List 三种。
tf.train.Feature() 的参数有 bytes_list、float_list、int64_list:
- tf.train.Feature(bytes_list=tf.train.BytesList(value=[...]))
- tf.train.Feature(int64_list=tf.train.Int64List(value=[...]))
- tf.train.Feature(float_list=tf.train.FloatList(value=[...]))
tf.train.FeatureList() 的参数是 tf.train.Feature() 的 list:
- tf.train.FeatureList(feature=[tf.train.Feature(bytes_list=tf.train.BytesList(value=[...]))])
- tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[...]))])
- tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=[...]))])
5.
def _int64_feature(value):
    """Return a tf.train.Feature holding the single int64 `value`."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrapper for inserting a bytes Feature into a SequenceExample proto.

    `bytes` values (e.g. raw JPEG data) are stored unchanged. Any other value
    is converted with str() and UTF-8 encoded. On Python 2 this matches the
    previous str(value) behaviour for ASCII text. On Python 3 it avoids two
    failures: str(bytes) would store the literal text "b'...'", and BytesList
    rejects str objects.
    """
    if isinstance(value, bytes):
        data = value
    else:
        data = str(value).encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
def _int64_feature_list(values):
    """Return a tf.train.FeatureList with one int64 Feature per element of `values`."""
    features = list(map(_int64_feature, values))
    return tf.train.FeatureList(feature=features)
def _bytes_feature_list(values):
    """Return a tf.train.FeatureList with one bytes Feature per element of `values`."""
    features = list(map(_bytes_feature, values))
    return tf.train.FeatureList(feature=features)
tf.gfile.FastGFile 是 TF 提供的读取文件的函数。读出的是未解码的图像字节,需要再用 tf.image.decode_jpeg() 解码。
tf.train.Features:它的参数 feature 是一个字典,k-v 对中 v 的类型是 Feature,对应每一个字段。
- tf.train.Features(feature={'k1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[...])), 'k2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[...]))})
- tf.train.Features(feature={'k1': tf.train.Feature(int64_list=tf.train.Int64List(value=[...])), 'k2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[...]))})
- tf.train.Features(feature={'k1': tf.train.Feature(float_list=tf.train.FloatList(value=[...])), 'k2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[...]))})
tfrecord 写入序列数据时,每个样本的长度不固定。
固定 shape 的数据使用 tf.train.Example();变长序列数据则需要使用 tf.train.SequenceExample()。tf.train.SequenceExample() 包括两部分:context 放置非序列化部分,feature_lists 放置变长序列。
6.
def _to_sequence_example(image, decoder, vocab):
    """Builds a SequenceExample proto for an image-caption pair.

    The image file is read as raw bytes and test-decoded first; files that
    fail to decode are skipped.

    Args:
      image: An ImageMetadata object holding exactly one tokenized caption.
      decoder: An ImageDecoder object.
      vocab: A Vocabulary object.

    Returns:
      A SequenceExample proto, or None if the file is not valid JPEG data.
    """
    # Binary mode: JPEG data is not text. Text mode would try to decode it
    # (as UTF-8 on Python 3) and fail.
    with tf.gfile.FastGFile(image.filename, "rb") as f:
        encoded_image = f.read()

    try:
        decoder.decode_jpeg(encoded_image)
    except (tf.errors.InvalidArgumentError, AssertionError):
        print("Skipping file with invalid JPEG data: %s" % image.filename)
        return None

    # Non-sequential data: the image id and the encoded image bytes.
    context = tf.train.Features(feature={
        "image/image_id": _int64_feature(image.image_id),
        "image/data": _bytes_feature(encoded_image),
    })

    # Each ImageMetadata passed here is expected to carry a single caption.
    assert len(image.captions) == 1
    caption = image.captions[0]
    caption_ids = [vocab.word_to_id(word) for word in caption]
    # Variable-length data: the caption tokens and their vocabulary ids.
    feature_lists = tf.train.FeatureLists(feature_list={
        "image/caption": _bytes_feature_list(caption),
        "image/caption_ids": _int64_feature_list(caption_ids)
    })
    return tf.train.SequenceExample(
        context=context, feature_lists=feature_lists)
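此段可以再斟酌一下
问题:
- ranges 的 shape?
  答:ranges 是 [start, end] 对的列表,长度为线程数,可看作 [num_threads, 2];第 i 个线程处理 images[ranges[i][0]:ranges[i][1]]。
- 为什么 +1(np.linspace 中的 num_shards_per_batch + 1)?
  答:np.linspace(start, stop, N + 1) 生成 N + 1 个边界点,相邻两点构成一个区间,于是正好得到 N 个 shard 区间。
- shards 的作用,num_shards 的作用?
  答:shard 是输出的 TFRecord 分片文件,num_shards 是一个数据集被切分成的文件个数;每个线程负责其中 num_shards / num_threads 个。
刷新输出:sys.stdout.flush()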
7. 处理图片文件
def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
                         num_shards):
    """Processes and saves a subset of images as TFRecord files in one thread.

    Each thread produces num_shards / num_threads shards. For instance, with
    num_shards = 128 and num_threads = 2, the first thread writes shards
    [0, 64) and the second writes [64, 128).

    Args:
      thread_index: Integer thread identifier within [0, len(ranges)).
      ranges: A list of [start, end) index pairs, one per thread, giving the
        slice of `images` each thread processes.
      name: Unique identifier specifying the dataset (e.g. "train").
      images: List of ImageMetadata.
      decoder: An ImageDecoder object.
      vocab: A Vocabulary object.
      num_shards: Integer number of shards for the output files.
    """
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    start = ranges[thread_index][0]
    end = ranges[thread_index][1]
    # N + 1 evenly spaced boundaries split [start, end) into N shard ranges.
    shard_ranges = np.linspace(start, end, num_shards_per_batch + 1).astype(int)
    num_images_in_thread = end - start

    counter = 0
    for s in xrange(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_dir, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)
        shard_counter = 0
        try:
            images_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1],
                                        dtype=int)
            for i in images_in_shard:
                sequence_example = _to_sequence_example(images[i], decoder, vocab)
                if sequence_example is None:
                    # Invalid image: skipped without touching the counters.
                    continue
                writer.write(sequence_example.SerializeToString())
                shard_counter += 1
                counter += 1
                # Report progress only when the counter actually advanced.
                if not counter % 1000:
                    print("%s [thread %d]: Processed %d of %d items in thread batch." %
                          (datetime.now(), thread_index, counter, num_images_in_thread))
                    sys.stdout.flush()
        finally:
            # Release the output file even if writing a record fails.
            writer.close()

        print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
              (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()

    print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
          (datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()
8.处理数据集
这个函数调用了 上面的函数;
Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
问题:
传进来的 num_shards 是什么?
答:是 main 中传入的 FLAGS.train_shards、FLAGS.val_shards、FLAGS.test_shards,即该数据集输出的 TFRecord 分片文件数。
def _process_dataset(name, images, vocab, num_shards):
    """Processes a complete data set and saves it as sharded TFRecord files.

    Every (image, caption) pair becomes its own entry. The entries are
    shuffled with a fixed seed and split into contiguous ranges, one per
    worker thread. Each thread writes its share of the shards.

    Args:
      name: Unique identifier specifying the dataset ("train", "val", "test").
      images: List of ImageMetadata.
      vocab: A Vocabulary object.
      num_shards: Integer number of shards for the output files.
    """
    # Break up each image into a separate entity for each caption.
    images = [ImageMetadata(image.image_id, image.filename, [caption])
              for image in images for caption in image.captions]

    # Shuffle the ordering of images. Make the randomization repeatable.
    random.seed(12345)
    random.shuffle(images)

    # Break the images into num_threads batches. Batch i is defined as
    # images[ranges[i][0]:ranges[i][1]].
    num_threads = min(num_shards, FLAGS.num_threads)
    # Builtin int: the np.int alias was deprecated in NumPy 1.20 and removed
    # in 1.24.
    spacing = np.linspace(0, len(images), num_threads + 1).astype(int)
    # Plain Python ints keep the "Launching ..." log line readable.
    ranges = [[int(lo), int(hi)] for lo, hi in zip(spacing[:-1], spacing[1:])]

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a utility for decoding JPEG images to run sanity checks.
    decoder = ImageDecoder()

    # Launch a thread for each batch.
    print("Launching %d threads for spacings: %s" % (num_threads, ranges))
    threads = []
    for thread_index in xrange(len(ranges)):
        args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
        t = threading.Thread(target=_process_image_files, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
          (datetime.now(), len(images), name))
9.创建词典
这里处理的是英文单词(用 nltk 分词);若处理中文,分词可以用 jieba。
def _create_vocab(captions):
    """Build a Vocabulary from tokenized training captions.

    Words seen fewer than FLAGS.min_word_count times are dropped. The rest are
    ordered by decreasing count (ties keep first-seen order) and written as
    "<word> <count>" lines to FLAGS.word_counts_output_file. A word's id is its
    0-based line number in that file; the unknown-word id is the number of
    kept words.

    Args:
      captions: A list of lists of strings.

    Returns:
      A Vocabulary object.
    """
    print("Creating vocabulary.")
    word_freq = Counter()
    for tokens in captions:
        word_freq.update(tokens)
    print("Total words:", len(word_freq))

    # Keep frequent words only, most frequent first.
    kept = sorted(
        (item for item in word_freq.items() if item[1] >= FLAGS.min_word_count),
        key=lambda item: item[1],
        reverse=True)
    print("Words in vocabulary:", len(kept))

    with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
        f.write("\n".join("%s %d" % (word, count) for word, count in kept))
    print("Wrote vocabulary file:", FLAGS.word_counts_output_file)

    words = [word for word, _ in kept]
    vocab_dict = {word: index for index, word in enumerate(words)}
    return Vocabulary(vocab_dict, len(words))
10.将一串caption词加上开始和结束符,并lower字符
def _process_caption(caption):
    """Lower-case and tokenize a caption, wrapping it in start/end tokens.

    Args:
      caption: A string caption.

    Returns:
      A list of strings: [FLAGS.start_word, <tokens>..., FLAGS.end_word].
    """
    start = FLAGS.start_word
    body = nltk.tokenize.word_tokenize(caption.lower())
    return [start] + body + [FLAGS.end_word]
11.从json中读取数据,并放到ImageMetadata中
def _load_and_process_metadata(captions_file, image_dir):
    """Read MSCOCO caption annotations and pair them with image file paths.

    Args:
      captions_file: JSON file containing caption annotations.
      image_dir: Directory containing the image files.

    Returns:
      A list of ImageMetadata, one per image, whose captions have been
      tokenized by _process_caption.
    """
    with tf.gfile.FastGFile(captions_file, "r") as f:
        caption_data = json.load(f)

    # (image_id, file name) for every entry in the "images" section.
    id_to_filename = [(entry["id"], entry["file_name"])
                      for entry in caption_data["images"]]

    # image_id -> list of raw caption strings; an image has several captions.
    # setdefault inserts an empty list the first time an id is seen and
    # returns the stored list either way.
    id_to_captions = {}
    for annotation in caption_data["annotations"]:
        image_id, caption = annotation["image_id"], annotation["caption"]
        id_to_captions.setdefault(image_id, []).append(caption)

    # Every image must have captions, and every caption must belong to an image.
    assert len(id_to_filename) == len(id_to_captions)
    assert {image_id for image_id, _ in id_to_filename} == set(id_to_captions.keys())
    print("Loaded caption metadata for %d images from %s" %
          (len(id_to_filename), captions_file))

    print("Processing captions.")
    image_metadata = []
    num_captions = 0
    for image_id, base_filename in id_to_filename:
        full_path = os.path.join(image_dir, base_filename)
        tokenized = [_process_caption(c) for c in id_to_captions[image_id]]
        image_metadata.append(ImageMetadata(image_id, full_path, tokenized))
        num_captions += len(tokenized)
    print("Finished processing %d captions for %d images in %s" %
          (num_captions, len(id_to_filename), captions_file))

    return image_metadata
12.main函数
def main(unused_argv):
    """Build the vocabulary and write the train/val/test TFRecord shards."""

    def _is_valid_num_shards(num_shards):
        """Returns True if num_shards is compatible with FLAGS.num_threads."""
        # Fewer shards than threads is fine (fewer threads get used);
        # otherwise the shards must divide evenly among the threads.
        return num_shards < FLAGS.num_threads or num_shards % FLAGS.num_threads == 0

    assert _is_valid_num_shards(FLAGS.train_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
    assert _is_valid_num_shards(FLAGS.val_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
    assert _is_valid_num_shards(FLAGS.test_shards), (
        "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")

    # Create the output directory if it does not exist yet.
    if not tf.gfile.IsDirectory(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load image metadata from caption files.
    mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
                                                      FLAGS.train_image_dir)
    mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
                                                    FLAGS.val_image_dir)

    # Redistribute the MSCOCO data as follows:
    #   train = 100% of mscoco_train_dataset + first 85% of mscoco_val_dataset
    #   val   = next 5% of mscoco_val_dataset (for validation during training)
    #   test  = last 10% of mscoco_val_dataset (for final evaluation)
    num_val = len(mscoco_val_dataset)
    train_cutoff = int(0.85 * num_val)
    val_cutoff = int(0.90 * num_val)
    train_dataset = mscoco_train_dataset + mscoco_val_dataset[:train_cutoff]
    val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
    test_dataset = mscoco_val_dataset[val_cutoff:]

    # The vocabulary is built from the training captions only.
    vocab = _create_vocab(
        [c for image in train_dataset for c in image.captions])

    for split_name, dataset, shards in (
            ("train", train_dataset, FLAGS.train_shards),
            ("val", val_dataset, FLAGS.val_shards),
            ("test", test_dataset, FLAGS.test_shards)):
        _process_dataset(split_name, dataset, vocab, shards)
if __name__ == "__main__":
tf.app.run() # main函数的运行入口
6.BUILD
/im2txt/BUILD
deps:依赖项
name:目标(target)的名字,通常与srcs中的文件名(去掉扩展名)一致
srcs:源文件的标签列表,这里是.py(或.sh)文件
# Bazel BUILD file for the //im2txt package.

# By default, targets here are visible only to the "internal" package group
# declared below.
package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

# Make the LICENSE file referenceable from other packages.
exports_files(["LICENSE"])

# Package group covering //im2txt and all of its subpackages
# ("..." matches recursively).
package_group(
    name = "internal",
    packages = [
        "//im2txt/...",
    ],
)

# Script that builds the train/val/test data and vocabulary from MSCOCO.
py_binary(
    name = "build_mscoco_data",
    srcs = [
        "data/build_mscoco_data.py",
    ],
)

# Download + preprocessing shell script. build_mscoco_data is listed in
# `data` so it is available at run time (presumably the script invokes it).
sh_binary(
    name = "download_and_preprocess_mscoco",
    srcs = ["data/download_and_preprocess_mscoco.sh"],
    data = [
        ":build_mscoco_data",
    ],
)

# srcs_version = "PY2AND3" declares the sources compatible with Python 2 and 3.
py_library(
    name = "configuration",
    srcs = ["configuration.py"],
    srcs_version = "PY2AND3",
)

# The Show and Tell model; uses the image embedding, image processing and
# input ops from //im2txt/ops.
py_library(
    name = "show_and_tell_model",
    srcs = ["show_and_tell_model.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//im2txt/ops:image_embedding",
        "//im2txt/ops:image_processing",
        "//im2txt/ops:inputs",
    ],
)

# Model test; size = "large" gives it Bazel's longer default timeout.
py_test(
    name = "show_and_tell_model_test",
    size = "large",
    srcs = ["show_and_tell_model_test.py"],
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

# Inference wrapper: implements the base interface from inference_utils on
# top of show_and_tell_model.
py_library(
    name = "inference_wrapper",
    srcs = ["inference_wrapper.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":show_and_tell_model",
        "//im2txt/inference_utils:inference_wrapper_base",
    ],
)

# Training entry point.
py_binary(
    name = "train",
    srcs = ["train.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

# Evaluation entry point.
py_binary(
    name = "evaluate",
    srcs = ["evaluate.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

# Caption-generation entry point: wraps the model for inference and uses the
# caption generator and vocabulary from inference_utils.
py_binary(
    name = "run_inference",
    srcs = ["run_inference.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":configuration",
        ":inference_wrapper",
        "//im2txt/inference_utils:caption_generator",
        "//im2txt/inference_utils:vocabulary",
    ],
)
/im2txt/inference_utils
7.BUILD
# Bazel BUILD file for the //im2txt/inference_utils package.

# By default, targets are visible only to the //im2txt:internal package group.
package(default_visibility = ["//im2txt:internal"])

licenses(["notice"])  # Apache 2.0

# Make the LICENSE file referenceable from other packages.
exports_files(["LICENSE"])

# Base class that model-specific inference wrappers extend.
py_library(
    name = "inference_wrapper_base",
    srcs = ["inference_wrapper_base.py"],
    srcs_version = "PY2AND3",
)

# Word <-> id vocabulary.
py_library(
    name = "vocabulary",
    srcs = ["vocabulary.py"],
    srcs_version = "PY2AND3",
)

# Beam-search caption generator.
py_library(
    name = "caption_generator",
    srcs = ["caption_generator.py"],
    srcs_version = "PY2AND3",
)

# Unit test for the caption generator.
py_test(
    name = "caption_generator_test",
    srcs = ["caption_generator_test.py"],
    deps = [
        ":caption_generator",
    ],
)
8.caption_generator.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import heapq #堆
import math
import numpy as np
1. 初始化
class Caption(object):
    """Represents a complete or partial caption.

    Captions are ordered and compared by `score` alone. This lets TopN keep
    them in a heapq min-heap and sort them.
    """

    def __init__(self, sentence, state, logprob, score, metadata=None):
        """Initializes the Caption.

        Args:
          sentence: List of word ids in the caption.
          state: Model state after generating the previous word.
          logprob: Log-probability of the caption.
          score: Score of the caption.
          metadata: Optional metadata associated with the partial sentence. If
            not None, a list of strings with the same length as 'sentence'.
        """
        self.sentence = sentence
        self.state = state
        self.logprob = logprob
        self.score = score
        self.metadata = metadata

    # 2. Score comparison.

    def __cmp__(self, other):
        """Compares Captions by score (Python 2 only; ignored by Python 3)."""
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1

    # Rich comparisons for Python 3, where __cmp__ is no longer used. All of
    # them are defined: with only __lt__/__eq__, `a <= b` and `a >= b` raise
    # TypeError. Non-Caption operands return NotImplemented instead of failing
    # an assert. So `caption == None` is simply False, and membership tests
    # over mixed lists keep working.

    def __lt__(self, other):
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score < other.score

    def __le__(self, other):
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score <= other.score

    def __gt__(self, other):
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score > other.score

    def __ge__(self, other):
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score >= other.score

    def __eq__(self, other):
        if not isinstance(other, Caption):
            return NotImplemented
        return self.score == other.score

    def __ne__(self, other):
        # Explicit for Python 2, which does not derive __ne__ from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # NOTE: because __eq__ is overridden, Python 3 makes Caption instances
    # unhashable (__hash__ is implicitly None).
3.TopN
class TopN(object):
    """Keeps the n largest elements pushed into it, using a heapq min-heap.

    Elements only need to support `<`. After extract() the container is
    invalid until reset() is called.
    """

    def __init__(self, n):
        self._n = n
        self._data = []

    def size(self):
        """Returns how many elements are currently held."""
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Adds x; once n elements are held, the smallest one is evicted."""
        assert self._data is not None
        heap = self._data
        if len(heap) >= self._n:
            # Full: push x and pop the current minimum (which may be x itself).
            heapq.heappushpop(heap, x)
        else:
            heapq.heappush(heap, x)

    def extract(self, sort=False):
        """Extracts all elements from the TopN. This is a destructive operation.

        The only method that can be called immediately after extract() is
        reset().

        Args:
          sort: If True, return the elements in descending sorted order.

        Returns:
          A list holding the (up to) n largest elements pushed.
        """
        assert self._data is not None
        data, self._data = self._data, None
        if sort:
            data.sort(reverse=True)
        return data

    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []
4.CaptionGenerator
class CaptionGenerator(object):
    """Class to generate captions from an image-to-text model.

    The constructor only records the model, the vocabulary and the
    beam-search settings on the instance.
    """

    def __init__(self,
                 model,
                 vocab,
                 beam_size=3,
                 max_caption_length=20,
                 length_normalization_factor=0.0):
        """Initializes the generator.

        Args:
          model: Object wrapping a trained image-to-text model. It must provide
            feed_image() and inference_step(), e.g. an InferenceWrapperBase
            instance.
          vocab: A Vocabulary object.
          beam_size: Number of partial captions kept at each search step.
          max_caption_length: Maximum caption length before the search stops.
          length_normalization_factor: If != 0, a number x such that captions
            are scored by logprob/length^x rather than logprob. This changes
            the relative scores of captions of different lengths; for example,
            x > 0 favors longer captions.
        """
        # Beam-search hyperparameters.
        self.beam_size = beam_size
        self.max_caption_length = max_caption_length
        self.length_normalization_factor = length_normalization_factor
        # Collaborators used during generation.
        self.model = model
        self.vocab = vocab
Beam Search:集束搜索
Beam Search(集束搜索)是一种启发式图搜索算法,通常用在图的解空间比较大的情况下,为了减少搜索所占用的空间和时间,在每一步深度扩展的时候,剪掉一些质量比较差的结点,保留下一些质量较高的结点。
BeamSearch示例:
test的时候,假设词表大小(vocabulary size)为3,内容为a,b,c;beam size为2(即每一步只保留2个得分最高的序列)。
1: 生成第1个词的时候,选择概率最大的2个词,假设为a,c,那么当前序列就是a,c
2:生成第2个词的时候,我们将当前序列a和c,分别与词表中的所有词进行组合,得到新的6个序列aa ab ac ca cb cc,然后从其中选择2个得分最高的,作为当前序列,假如为aa cb
3:后面会不断重复这个过程,直到遇到结束符为止。最终输出2个得分最高的序列。
9.inference_wrapper_base.py
问题:
这里build_model没有用? #build_model由子类InferenceWrapper在inference_wrapper.py中实现,build_graph_from_config中的self.build_model(model_config)调用的就是它
1.将checkpoint_path里的训练模型恢复到sess中
def _create_restore_fn(self, checkpoint_path, saver):
    """Creates a function that restores a model from a checkpoint.

    Args:
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.
      saver: Saver for restoring variables from the checkpoint file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.

    Raises:
      ValueError: If checkpoint_path does not refer to a checkpoint file or a
        directory containing a checkpoint file.
    """
    # Keep the caller's path for the error message. For a directory,
    # checkpoint_path is reassigned below and becomes None when no checkpoint
    # is found, so the message would otherwise read "...found in: None".
    requested_path = checkpoint_path
    if tf.gfile.IsDirectory(checkpoint_path):
        # Resolve the directory to the most recently saved checkpoint in it.
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
    if not checkpoint_path:
        raise ValueError("No checkpoint file found in: %s" % requested_path)

    def _restore_fn(sess):
        """Restores the saver's variables from the resolved checkpoint into sess."""
        tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
        saver.restore(sess, checkpoint_path)
        tf.logging.info("Successfully loaded checkpoint: %s",
                        os.path.basename(checkpoint_path))

    return _restore_fn
2.根据model_config构建推理图并创建saver,返回一个从checkpoint_path恢复模型变量的函数
def build_graph_from_config(self, model_config, checkpoint_path):
    """Builds the inference graph from a configuration object.

    The graph is built first (via the subclass's build_model). The Saver is
    created afterwards so that it covers the variables the model just created.

    Args:
      model_config: Object containing configuration for building the model.
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.
    """
    tf.logging.info("Building model.")
    self.build_model(model_config)
    return self._create_restore_fn(checkpoint_path, tf.train.Saver())
3.
问题:
Load the Graph的定义没有用到? #tf.import_graph_def(graph_def, name="")会把读到的图定义导入当前默认图,所以虽然没有变量接收返回值,这个图定义已经被使用了
def build_graph_from_proto(self, graph_def_file, saver_def_file,
                           checkpoint_path):
    """Builds the inference graph from serialized GraphDef and SaverDef protos.

    Args:
      graph_def_file: File containing a serialized GraphDef proto.
      saver_def_file: File containing a serialized SaverDef proto.
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.
    """

    def _parse_proto_file(proto, path):
        """Fills `proto` from the binary serialized message stored at `path`."""
        with tf.gfile.FastGFile(path, "rb") as f:
            proto.ParseFromString(f.read())
        return proto

    # Import the graph into the current default graph, without a name prefix.
    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
    tf.import_graph_def(_parse_proto_file(tf.GraphDef(), graph_def_file),
                        name="")

    # Rebuild the Saver from its serialized definition.
    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
    saver = tf.train.Saver(
        saver_def=_parse_proto_file(tf.train.SaverDef(), saver_def_file))
    return self._create_restore_fn(checkpoint_path, saver)
4.由子类实现的接口(具体实现见inference_wrapper.py中的InferenceWrapper)
def feed_image(self, sess, encoded_image)
def inference_step(self, sess, input_feed, state_feed)
10.caption_generator_test.py
测试caption_generator.py的脚本
1. 创建一个假字典
class FakeVocab(object):
    """Fake Vocabulary for testing purposes.

    Only the special-token ids are shown here; the rest of the class is
    elided ("...") in these notes.
    """

    def __init__(self):
        self.start_id = 0  # Word id denoting sentence start.
        self.end_id = 1  # Word id denoting sentence end.
        ...  # Remainder elided in these notes.
11.vocabulary.py
class Vocabulary(object):
    """Vocabulary class for an image-to-text model.

    Maps words to integer ids (vocab) and ids back to words (reverse_vocab).
    Unknown words and out-of-range ids map to the unknown word.
    """

    def __init__(self,
                 vocab_file,
                 start_word="<S>",
                 end_word="</S>",
                 unk_word="<UNK>"):
        """Initializes the vocabulary.

        Args:
          vocab_file: File containing the vocabulary. The word is the first
            whitespace-separated token on each line (other tokens are ignored),
            and its id is the line number.
          start_word: Special word denoting sentence start.
          end_word: Special word denoting sentence end.
          unk_word: Special word denoting unknown words. It is appended as the
            last id if the file does not contain it.
        """
        if not tf.gfile.Exists(vocab_file):
            # NOTE(review): tf.logging.fatal appears to only log, not abort; if
            # so, a missing file surfaces as an error from GFile below. Confirm.
            tf.logging.fatal("Vocab file %s not found.", vocab_file)
        tf.logging.info("Initializing vocabulary from file: %s", vocab_file)

        with tf.gfile.GFile(vocab_file, mode="r") as f:
            reverse_vocab = [line.split()[0] for line in f.readlines()]
        assert start_word in reverse_vocab
        assert end_word in reverse_vocab
        if unk_word not in reverse_vocab:
            reverse_vocab.append(unk_word)
        # If a word is repeated in the file, its later line number wins here.
        vocab = {word: word_id for word_id, word in enumerate(reverse_vocab)}

        tf.logging.info("Created vocabulary with %d words", len(vocab))

        self.vocab = vocab  # vocab[word] = id
        self.reverse_vocab = reverse_vocab  # reverse_vocab[id] = word

        # Save special word ids.
        self.start_id = vocab[start_word]
        self.end_id = vocab[end_word]
        self.unk_id = vocab[unk_word]

    def word_to_id(self, word):
        """Returns the integer word id of a word string (unk_id if unknown)."""
        return self.vocab.get(word, self.unk_id)

    def id_to_word(self, word_id):
        """Returns the word string of an integer word id.

        Ids outside [0, len(reverse_vocab)) map to the unknown word. The lower
        bound matters: without it, Python's negative indexing would silently
        return a word from the end of the list.
        """
        if 0 <= word_id < len(self.reverse_vocab):
            return self.reverse_vocab[word_id]
        return self.reverse_vocab[self.unk_id]
12.inference_wrapper.py
调用接口文件
class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
    """Model wrapper class for performing inference with a ShowAndTellModel.

    Tensors are fetched and fed by name ("image_feed:0",
    "lstm/initial_state:0", "softmax:0", ...), so those names must exist in the
    graph the session runs.
    """

    def __init__(self):
        super(InferenceWrapper, self).__init__()

    def build_model(self, model_config):
        """Builds a ShowAndTellModel in "inference" mode and returns it."""
        inference_model = show_and_tell_model.ShowAndTellModel(
            model_config, mode="inference")
        inference_model.build()
        return inference_model

    def feed_image(self, sess, encoded_image):
        """Feeds an encoded image and returns the initial LSTM state."""
        return sess.run(fetches="lstm/initial_state:0",
                        feed_dict={"image_feed:0": encoded_image})

    def inference_step(self, sess, input_feed, state_feed):
        """Runs one step of the LSTM.

        Returns:
          A tuple (softmax_output, state_output, None). The third element is
          always None here (presumably the optional per-step metadata slot).
        """
        feeds = {
            "input_feed:0": input_feed,
            "lstm/state_feed:0": state_feed,
        }
        softmax_output, state_output = sess.run(
            fetches=["softmax:0", "lstm/state:0"], feed_dict=feeds)
        return softmax_output, state_output, None
/im2txt/
13.configuration.py
问题:
self.values_per_input_shard = 2300;
self.input_queue_capacity_factor = 2
解决梯度爆炸问题的方法
- 首先设置一个梯度阈值:clip_gradient
- 在反向传播中求出各参数的梯度,这里我们不直接使用梯度进行参数更新,而是先求这些梯度的l2范数
- 然后比较梯度的l2范数||g||与clip_gradient的大小
- 如果前者大,求缩放因子clip_gradient/||g||, 由缩放因子可以看出梯度越大,则缩放因子越小,这样便很好地控制了梯度的范围
- 最后将梯度乘上缩放因子便得到最后所需的梯度
14.run_inference.py
tf.Graph()参考
tf.Graph().as_default() 表示将这个类实例,也就是新生成的图作为整个 tensorflow 运行环境的默认图,如果只有一个主线程不写也没有关系
tf.Graph.finalize()表示图构建完毕,将其设为只读。
15.show_and_tell_model.py
问题:
input_mask
inference mode 是做什么的
Dequeue()参考
dynamic_rnn实现的功能就是可以让不同迭代传入的batch可以是长度不同数据,但同一次迭代一个batch内部的所有数据长度仍然是固定的
16.train.py
tf.train.exponential_decay函数(指数衰减法)