最近在读这篇文章,顺便记录些东东。。。
论文原题目是《Learning Deep Representations of Fine-Grained Visual Descriptions》(链接),程序在GitHub上有(链接),用了Torch框架(总觉得这个框架的文档有点杂乱。。。有挺多坑要去踩的。。。虽然贫僧觉得caffe坑更加多。。。)来搭神经网络(这个框架主要是用Lua
语言,其实和Python
有点像,比较容易上手的还是)。
这群人做了什么
训练出了一种无监督学习模型,能够根据你提供的一句话来搜出满足这句话的图像。
模型训练目标
1 N ∑ n = 1 N Δ ( y n , f v ( v n ) ) + Δ ( y n , f t ( t n ) ) \frac{1}{N} \sum_{n=1}^{N}\Delta(y_n, f_v(v_n)) + \Delta(y_n, f_t(t_n)) N1n=1∑NΔ(yn,fv(vn))+Δ(yn,ft(tn))
视觉信息 v ∈ V v \in V v∈V(这里只是定义,其实用通俗的话来说就是单张图片 v v v属于图片数据库 V V V),文字描述 t ∈ T t \in T t∈T且类别标签 y ∈ Y y \in Y y∈Y,学习函数(就是后面要训练的模型部分) f v : V → Y f_v : V \rightarrow Y fv:V→Y, f t : T → Y f_t: T \rightarrow Y ft:T→Y。这里的 N N N是指数据集中图像-文本对的数量,所以一个图像可以有多个不同的文本描述。
将 Δ : Y × Y → R \Delta : Y \times Y \rightarrow R Δ:Y×Y→R, Δ \Delta Δ是由 0 0 0和 1 1 1构成的损失函数减小到能够接受的程度的时候就是达到了最后目标了。上面这个公式就是DS-SJE(deep symmetric structured joint embedding),如果只优化 f v f_v fv的话那么就是DA-SJE(deep asymmetric structured joint embedding)(如果是只优化另一个的话也可以,但是作者说还没有看到过有人这么做过)。
更加具体的东东这里就不重复了,看下面参考里面的链接吧。
模型优势
不需要人为标定图片的特征,直接在图片和对应的文本上进行训练就可以达到在人为标定特征的数据集上训练的模型的效果(甚至更好),让模型的适用性更强(毕竟人为标定特征的数据集不多,而且工作量也大,应用起来也不方便)。
相关代码阅读
代码网址看上面给出的链接,这里的代码用了的是Torch来写的,要自己看看lua和torch教程,这里就不展开来说了。
如何从Torch7已经训练好了的模型中提取出权重
内容主要关于如何读取Learning Deep Representations of Fine-grained Visual Descriptions论文配套的模型。
建议:在继续读之前请先初步了解下Lua语言的用法、Torch7的使用方法及Torch7的nnGraph包的基础使用方法(不用太精通啊喂,只要能读懂别人的代码就ok得不行啦。把基本操作过一遍就可以了,遇到不懂的再查)。
1. 读取模型的代码
require 'nn';
require 'cudnn';
require 'cunn';
require 'nngraph';
require 'torch';
m = torch.load('a.t7')
a.t7
就是要读取的模型的名字,在执行命令的时候要和模型在同一个目录下(不然就要用到绝对路径)。
1.1 m的keys
- val_loss(数字,貌似没什么用)
- protos(nn.gModule类型,训练好的模型在这里面)
- epoch(应该只是设置相关的数字,可能是每次载入的数据量?)
- train_losses(一堆数字,记录了训练过程中的loss)
- opt(记录配置目录、数据集目录之类的东东)
- val_losses(数字,貌似没什么用)
- i(训练总次数,记录用的)
上面table
里面的文字记录部分主要是用来计算accuracy和evaluate用的,具体的使用方法看相应的脚本。
1.1 读取CVPR2016中真正的模型
其实就是读取训练好了的网络模型,上面说到的读取模型是指训练网络之后保存的.t7格式文件。
protos = m["protos"]
for key, value in pairs(protos) do
print(key)
end
2. PyTorch中读取.t7格式的模型
from torch.utils.serialization import load_lua
x = load_lua('x.t7')
实际上因为论文代码训练出来的模型不能用pytorch读取,因为用到了nngraph包。如果想要在pytorch重现结果的话需要自己重新搭建网络,具体网络的搭建方法看train_sje_hybrid.lua
里面调用了module文件夹里的什么脚本训练出来的。
3. protos的keys
- enc_image
- enc_doc
3.1初步读取到模型权重及相关内容
这里只是读取模型的参数。
th> print(protos.enc_doc:parameters())
{
1 : CudaTensor