TensorFlow implementation of the Show and Tell: A Neural Image Caption Generator model

"""
设计步骤
1. Data generaor:训练模型时提供数据
    a. load vocab
    b. load image feature
    c.provide data for training
    
2.Build image caption model
3.Trains the model
"""

import os
import sys
import math
import random
import pprint
import pickle
import numpy as np
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging

# Input description (caption) file
input_description_file = ""
# Directory holding the pre-extracted image feature pickles
input_img_feature_dir = ""
# Generated vocabulary file
input_vocab_file = ""
output_dir = ""

if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)
    

# Default hyperparameters for the model
def get_default_params():
    return tf.contrib.training.HParams(
        # vocab filtering: drop words occurring fewer times than this threshold
        num_vocab_word_threshold=3,
        # LSTM structure parameters
        num_embedding_nodes=32,
        num_timesteps=10,
        num_lstm_nodes=[64, 64],
        num_lstm_layers=2,
        num_fc_nodes=32,
        batch_size=80,
        cell_type="lstm",
        # gradient clipping
        clip_lstm_grads=1.0,
        learning_rate=0.001,
        keep_prob=0.8,
        # how often to log
        log_frequent=100,
        # how often to save a checkpoint
        save_frequent=1000,
    )

hps=get_default_params()
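# Optional: dump the resolved hyperparameters once for reproducibility
# (to_json is part of tf.contrib.training.HParams in TF 1.x).
logging.info('hyperparameters: %s' % hps.to_json())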
    
# Vocabulary loading
class Vocab(object):
    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}
        self._word_to_id = {}
        self._unk = -1
        # id of the end-of-sentence token
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        self._read_dict(filename)
    
    def _read_dict(self, filename):
        with gfile.GFile(filename, 'r') as f:
            lines = f.readlines()
        for line in lines:
            word, occurrence = line.strip('\r\n').split('\t')
            occurrence = int(occurrence)
            if occurrence < self._word_num_threshold:
                continue
            idx = len(self._id_to_word)
            if word == '<UNK>':
                self._unk = idx
            elif word == '.':
                self._eos = idx
            if word in self._word_to_id or idx in self._id_to_word:
                raise Exception('duplicate entry in vocab file: %s' % word)
            self._word_to_id[word] = idx
            self._id_to_word[idx] = word
    
    @property
    def unk(self):
        return self._unk
    
    @property
    def eos(self):
        return self._eos
    
    def id_to_word(self,word_id):
        return self._id_to_word.get(word_id,'<UNK>')
    
    def word_to_id(self,word):
        return self._word_to_id.get(word,self.unk)
    
    def size(self):
        return len(self._id_to_word)
    
    def encode(self,sentence):
        return [self.word_to_id(word) for word in sentence.split(' ')]
    def decode(self,sentence_id):
        words= [self.id_to_word(word_id) for word_id in sentence_id]
        return ' '.join(words)
     
vocab=Vocab(input_vocab_file,hps.num_vocab_word_threshold)
vocab_size=vocab.size()
print(vocab_size)
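# The vocab file is expected to hold one "word<TAB>count" pair per line,
# which is exactly what Vocab._read_dict parses, e.g. (illustrative):
#   the     1024
#   <UNK>   3
# A quick round trip through the vocab (illustrative sentence; the actual
# ids depend on the contents of input_vocab_file):
demo_sentence = 'a dog runs .'
demo_sentence_ids = vocab.encode(demo_sentence)
logging.info('encode(%r) -> %s -> decode: %r'
             % (demo_sentence, demo_sentence_ids, vocab.decode(demo_sentence_ids)))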
    
def parse_token_file(token_file):
    img_name_to_tokens={}
    with gfile.GFile(token_file,'r') as f:
        lines=f.readlines()
        
    for line in lines:
        img_id,description =line.strip('\r\n').split('\t')
        img_name,_=img_id.split('#')
        img_name_to_tokens.setdefault(img_name,[])
        img_name_to_tokens[img_name].append(description)
    return img_name_to_tokens

# Convert every description of every image into token ids
def convert_token_to_id(img_name_to_tokens,vocab):
    img_name_to_tokens_id={}
    for img_name in img_name_to_tokens:
        img_name_to_tokens_id.setdefault(img_name,[])
        for description in img_name_to_tokens[img_name]:
            token_ids=vocab.encode(description)
            img_name_to_tokens_id[img_name].append(token_ids)
    return img_name_to_tokens_id


img_name_to_tokens=parse_token_file(input_description_file)
img_name_to_tokens_id=convert_token_to_id(img_name_to_tokens,vocab)
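# The description file is assumed to follow the Flickr-style token format
# that parse_token_file expects: "<img_name>#<caption_idx>\t<caption>" per
# line, e.g. (illustrative):
#   1000268201_693b08cb0e.jpg#0	A child in a pink dress is climbing .
logging.info("num of captions for one sample image: %d"
             % len(next(iter(img_name_to_tokens.values()))))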

logging.info("num of all images: %d" % len(img_name_to_tokens))
# Feeds data batches to the model
class ImageCaptionData(object):
    def __init__(self, img_name_to_tokens_id, img_feature_dir, num_timesteps,
                 vocab,
                 deterministic=False):
        self._vocab = vocab
        self._img_name_to_tokens_id = img_name_to_tokens_id
        self._num_timesteps = num_timesteps
        self._deterministic = deterministic
        self._indicator = 0

        self._img_feature_filenames = []
        self._img_feature_data = []

        self._all_img_feature_filepaths = []
        for filename in gfile.ListDirectory(img_feature_dir):
            self._all_img_feature_filepaths.append(
                os.path.join(img_feature_dir, filename))
        pprint.pprint(self._all_img_feature_filepaths)
        self._load_img_feature_pickle()

        if not self._deterministic:
            self._random_shuffle()

    # Load the image features from the pickle files
    def _load_img_feature_pickle(self):
        for filepath in self._all_img_feature_filepaths:
            logging.info("loading %s" % filepath)
            with gfile.GFile(filepath, 'rb') as f:
                filenames, features = pickle.load(f)
                # filenames is a list; features is an ndarray
                self._img_feature_filenames += filenames
                self._img_feature_data.append(features)

        self._img_feature_data = np.vstack(self._img_feature_data)
        # features come out of the CNN as [n, 1, 1, dim]; squeeze to [n, dim]
        origin_shape = self._img_feature_data.shape
        self._img_feature_data = np.reshape(
            self._img_feature_data,
            (origin_shape[0], origin_shape[3]))
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)

        print(self._img_feature_data.shape)
        print(self._img_feature_filenames.shape)
        
    
    def size(self):
        return len(self._img_feature_filenames)
    
    def img_feature_size(self):
        return self._img_feature_data.shape[1]
    
    def _random_shuffle(self):
        p=np.random.permutation(self.size())
        self._img_feature_filenames=self._img_feature_filenames[p]
        self._img_feature_data=self._img_feature_data[p]
        
    # Pick one caption per image and pad/truncate it to num_timesteps,
    # producing a 0/1 weight mask over the real (non-padding) tokens
    def _img_desc(self, batch_filenames):
        batch_sentence_ids = []
        batch_weights = []
        for filename in batch_filenames:
            token_ids_set = self._img_name_to_tokens_id[filename]
            chosen_token_ids = random.choice(token_ids_set)
            chosen_token_ids_length = len(chosen_token_ids)

            weight = [1 for i in range(chosen_token_ids_length)]
            if chosen_token_ids_length >= self._num_timesteps:
                chosen_token_ids = chosen_token_ids[0:self._num_timesteps]
                weight = weight[0:self._num_timesteps]
            else:
                remaining_length = self._num_timesteps - chosen_token_ids_length
                chosen_token_ids += [self._vocab.eos for i in range(remaining_length)]
                weight += [0 for i in range(remaining_length)]
            batch_sentence_ids.append(chosen_token_ids)
            batch_weights.append(weight)

        batch_sentence_ids = np.asarray(batch_sentence_ids)
        batch_weights = np.asarray(batch_weights)
        return batch_sentence_ids, batch_weights
        
    # Return the next batch to the model
    def next_batch(self, batch_size):
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size

        assert end_indicator <= self.size()

        batch_filenames = self._img_feature_filenames[self._indicator:end_indicator]
        batch_img_features = self._img_feature_data[self._indicator:end_indicator]

        batch_sentence_ids, batch_weights = self._img_desc(batch_filenames)
        self._indicator = end_indicator
        return batch_img_features, batch_sentence_ids, batch_weights, batch_filenames
        
caption_data = ImageCaptionData(img_name_to_tokens_id,
                                input_img_feature_dir,
                                hps.num_timesteps,
                                vocab)
img_feature_dim=caption_data.img_feature_size()
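# A small sanity check of the data pipeline: one batch should come back as
# features [batch_size, img_feature_dim] plus sentence ids and weights both
# shaped [batch_size, num_timesteps] (this consumes one batch from the
# generator, which is harmless since the data is shuffled):
demo_feats, demo_batch_ids, demo_weights, demo_names = \
    caption_data.next_batch(hps.batch_size)
logging.info('batch shapes: %s, %s, %s'
             % (demo_feats.shape, demo_batch_ids.shape, demo_weights.shape))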
def create_rnn_cell(hidden_dim, cell_type):
    if cell_type == 'lstm':
        return tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple=True)
    elif cell_type == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_dim)
    else:
        raise Exception("unsupported cell type: %s" % cell_type)
        
def dropout(cell,keep_prob):
    return tf.contrib.rnn.DropoutWrapper(cell,output_keep_prob=keep_prob)

# Build the computation graph
def get_train_model(hps, vocab_size, img_feature_dim):
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size

    img_feature = tf.placeholder(tf.float32, (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    # marks which positions are real tokens (1) vs. padding (0)
    mask = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(tf.zeros([], tf.int32),
                              name="global_step",
                              trainable=False)

    # Prediction process:
    # ground-truth sentence: [a, b, c, d, e]
    # img_feature: [0.4, 0.3, 10, 2]
    # img_feature -> embedding_img -> lstm -> (a)
    # predict: a -> embedding_word -> lstm -> (b)
    # ...

    # Sets up the embedding layer
    embedding_initializer = tf.random_uniform_initializer(-1.0, 1.0)
    with tf.variable_scope('embedding', initializer=embedding_initializer):
        embeddings = tf.get_variable('embedding',
                                     [vocab_size, hps.num_embedding_nodes],
                                     tf.float32)
        # embed_token_ids: [batch_size, num_timesteps-1, num_embedding_nodes]
        embed_token_ids = tf.nn.embedding_lookup(
            embeddings,
            sentence[:, 0:num_timesteps - 1])

    img_feature_embed_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('img_feature_embed',
                           initializer=img_feature_embed_init):
        # img_feature: [batch_size, img_feature_dim]
        # embed_img: [batch_size, num_embedding_nodes]
        embed_img = tf.layers.dense(img_feature, hps.num_embedding_nodes)
        # embed_img: [batch_size, 1, num_embedding_nodes]
        embed_img = tf.expand_dims(embed_img, 1)
        # embed_inputs: [batch_size, num_timesteps, num_embedding_nodes]
        embed_inputs = tf.concat([embed_img, embed_token_ids], axis=1)

    # Sets up the rnn network
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1])
    rnn_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_nn', initializer=rnn_init):
        # one cell per layer
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cell = dropout(cell, keep_prob)
            cells.append(cell)
        cell = tf.contrib.rnn.MultiRNNCell(cells)

        # define the initial state
        init_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                           embed_inputs,
                                           initial_state=init_state)

    # Sets up the fully connected layer
    fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)
    with tf.variable_scope('fc', initializer=fc_init):
        rnn_outputs_2d = tf.reshape(rnn_outputs, [-1, hps.num_lstm_nodes[-1]])
        fc1 = tf.layers.dense(rnn_outputs_2d, hps.num_fc_nodes, name="fc1")
        fc1_dropout = tf.contrib.layers.dropout(fc1, keep_prob)
        fc1_relu = tf.nn.relu(fc1_dropout)
        logits = tf.layers.dense(fc1_relu, vocab_size, name="logits")

    # Loss: masked softmax cross-entropy, so padding does not contribute
    with tf.variable_scope('loss'):
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])
        mask_sum = tf.reduce_sum(tf.cast(mask_flatten, tf.float32))

        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=sentence_flatten)
        weighted_softmax_loss = tf.multiply(
            softmax_loss, tf.cast(mask_flatten, tf.float32))
        loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum

        prediction = tf.argmax(logits, 1, output_type=tf.int32)
        correct_prediction = tf.equal(prediction, sentence_flatten)
        weighted_correct_prediction = tf.multiply(
            tf.cast(correct_prediction, tf.float32),
            tf.cast(mask_flatten, tf.float32))
        accuracy = tf.reduce_sum(weighted_correct_prediction) / mask_sum
        tf.summary.scalar('loss', loss)

    # Train op: clip gradients globally, then apply them with Adam
    with tf.variable_scope('train_op'):
        # collect all trainable variables
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info('variable name: %s' % var.name)
        grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, tvars), hps.clip_lstm_grads)
        optimizer = tf.train.AdamOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step=global_step)

    return ((img_feature, sentence, mask, keep_prob),
            (loss, accuracy, train_op),
            global_step)


placeholder, metrics, global_step = get_train_model(
    hps, vocab_size, img_feature_dim)

img_feature, sentence, mask, keep_prob = placeholder
loss, accuracy, train_op = metrics

summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=10)

# Model training
training_steps = 1000

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    for i in range(training_steps):
        (batch_img_features,
         batch_sentence_ids,
         batch_weights,
         _) = caption_data.next_batch(hps.batch_size)
        input_vals = (batch_img_features, batch_sentence_ids,
                      batch_weights, hps.keep_prob)
        feed_dict = dict(zip(placeholder, input_vals))
        fetches = [global_step, loss, accuracy, train_op]
        # decide whether to log metrics / save a checkpoint on this step
        should_log = (i + 1) % hps.log_frequent == 0
        should_save = (i + 1) % hps.save_frequent == 0

        if should_log:
            fetches += [summary_op]
        outputs = sess.run(fetches, feed_dict=feed_dict)
        global_step_val, loss_val, accuracy_val = outputs[0:3]
        if should_log:
            summary_str = outputs[-1]
            writer.add_summary(summary_str, global_step_val)
            logging.info('step: %5d, loss: %3.3f, accu: %3.3f'
                         % (global_step_val, loss_val, accuracy_val))

        if should_save:
            model_save_file = os.path.join(output_dir, "image_caption")
            logging.info('step: %5d, model saved' % global_step_val)
            saver.save(sess, model_save_file, global_step=global_step_val)
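
# A minimal sketch of reloading the most recent checkpoint, e.g. to resume
# training or run evaluation later (assumes the loop above saved at least
# one checkpoint into output_dir). At inference time, per the prediction
# process sketched in get_train_model, one would feed the image embedding
# first and then feed each predicted word back in as the next input.
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(output_dir)
    if ckpt:
        saver.restore(sess, ckpt)
        logging.info('restored model from %s' % ckpt)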
            
            
