Ba, Jimmy, Volodymyr Mnih, and Koray Kavukcuoglu. “Multiple object recognition with visual attention.” arXiv preprint arXiv:1412.7755 (2014).
思想
三位作者均来自于风头正劲的Google DeepMind,三作Koray Kavukcuoglu在AlphaGo的Nature论文中榜上有名。
本文执行的任务相对简单:从图片中识别长度、位置未知的手写数字串。但包含了当今神经网络的诸多热点方向,包括:
- 聚焦机制(Attention):每次只看输入的一小部分,诸次移动观察范围。
- 循环神经网络(Recurrent NN):在每一次移动和输出之间建立记忆
- 增强学习(Reinforcement learning):在训练过程中,根据不可导的反馈,从当前位置产生探索性的采样。
本文和前一篇文章中介绍的RAM(Recurrent Visual Attention Model)算法极为相似,但是更侧重数学推导。建议先阅读这篇博客中的解读。
对于增强学习没概念的同学,也可以参考这篇博客:Torch中的增强学习层
模型
核心数据
X X X: 输入图像
n n n: 步骤序号,共有 N N N个步骤,每次查看图像一小部分。
l n l_n ln: 第 n n n步查看的图像位置。整数类型xy坐标,图像中心为(0,0),图像边缘对应的坐标为系统超参数,决定搜索粒度。
x n x_n xn: 第 n n n步观察到的图像内容,称为glimpse。是以 l n l_n ln为中心,尺寸相同,缩放和范围等差的图像金字塔。
特别要注意的是: x n x_n xn没法对 l n l_n ln求导。
子网络
整个系统由若干部分组成,执行不同功能。系统的组成部件都称为网络。
系统中变量繁多,不必急于看全图,顺序推导即可。
Glimpse网络
输入:当前位置 l n l_n ln,当前图像块 x n x_n xn
输出:当前观察的信息 g n g_n gn
形式:
g n = G i m a g e ( x n ∣ W i m a g e ) ⊙ G l o c ( l n ∣ W l o c ) g_n = G_{image}(x_n|W_{image})\odot G_{loc}(l_n|W_{loc}) gn=Gimage(xn∣Wimage)⊙Gloc(ln∣Wloc)
G i m a g e G_{image} Gimage和 G l o c G_{loc} Gloc是两个网络,其参数为 W i m a g e W_{image} Wimage和 W l o c W_{loc} Wloc。分别把图像(what)和位置(where)编码成统一维度的信息,进行点乘。
作用:通过小范围观测,提取纹理和位置信息。
条件号后面的 W ∗ W_* W∗表示某网络参数,此后不再赘述。
Recurrent网络
输入:当前观察信息 g n g_n gn,上一步状态 r n − 1 1 , r n − 1 2 r_{n-1}^1,r_{n-1}^2 rn−11,rn−12
输出:当前的两个循环状态 r n 1 , r n 2 r_n^1,r_n^2 rn1,rn2
形式:
r n 1 = R r e c u r ( g n , r n − 1 1 ∣ W r 1 ) r_n^1 = R_{recur}(g_n,r_{n-1}^1|W_{r1}) rn1=Rrecur(gn,rn−11∣Wr1)
r n 2 = R r e c u r ( r n 1 , r n − 1 2 ∣ W r 2 ) r_n^2 = R_{recur}(r_n^1,r_{n-1}^2|W_{r2}) rn2=Rrecur