Pytorch学习笔记-第十章图像描述
记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第十章图像描述的学习笔记。
主要分析实现代码里面main,data,data_preprocess,feature_extract这5个文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。
model
整个工作的模型如下。左边是卷积网络,实现代码里面用的是resnet,用来处理输入图片,获取高层特征;右边是LSTM,用来生成图片的描述(具体过程类似第9章里生成诗歌的过程)。
其中图片经过resnet处理之后会变成2048维的一个向量,最后经过一个fc层变成和词向量一样的256维数据作为LSTM的第一个输入去预测图片的描述。
data_preprocess
原始的图片描述数据是JSON结构化文件,为了方便后续使用,该部分代码进行一些预处理,进行中文分词,以及丢弃词频不够的词,丢弃长度过长的词,最后生成编号和词语互相对应以及图片编号和描述编号互相对应的字典。
data
由于不同的描述可能长度不一样,真正使用之前需要补齐成同一长度。
for i, c in enumerate(caps):
end_cap = lengths[i] - 1
if end_cap < batch_length:
cap_tensor[end_cap, i] = eos
cap_tensor[:end_cap, i].copy_(c[:end_cap]