一、基本思想
CNN+RNN CNN用的是VGG16 RNN部分用的是LSTM.换成resnet101效果会更好。
二、模型结构
四、代码分析:
首先是训练的部分
(1)准备数据
COCO数据集中的caption最长不能超过16个单词,超出的话只截取前16个。每张图保留5个caption。
根据词频建立词汇库,一般选出现次数>5次的单词,flickr数据集一般是3次。自己建立的词汇库效果不是很好,网上有分好的data_coco.json和data_flickr.json,可以直接拿来用,效果很好。同时,里面也已经对数据集(训练集 验证集和测试集)做了划分 。
coco 110000+5000+5000
flickr 28000+1000+1000
(2)建立模型
分为两个部分
1)CNN: 使用vgg16模型在Imagenet上预训练的权重。提取最后的全连接层特征。
2)LSTM
LSTM的输入向量、 隐含层向量都是512维。每个词汇库里的单词用embedding来表示。比如torch.nn.embedding() 它随着训练过程一起变化。
关于训练的细节,附一张图:
每个句子的开头加了一个<bos>,结尾加了一个<eos> 都用0表示,句子长度为1+16+1
其他单词用 在词汇库中的索引表示。
和(二)中的图对比,最初的时刻(0时刻)输入图像特征,但是并不做预测,或者说预测的结果根本不使用。从1时刻开始,输入<bos>的编码,预测下一个单词的概率,每一个时刻的输出为词汇库大小的概率向量。
因为预测结果包含了<eos> 所以最后的prediction长度为16+1=17
训练过程中用mask记录每个caption的长度。<eos>以及后面不到长度的部分,全部为0,前面是1,
我们希望每个单词预测出来的词汇库大小的向量,gt单词位置处的概率最大,最接近于1.对应的-logP最小。