Neural Turing Machines - 神经图灵机
参考论文:Graves, A., Wayne, G., & Danihelka, I… (2014). Neural Turing Machines.
引言
一般的神经网络不具有记忆,LSTM可以利用之前的信息,但随着序列的加深,靠前的信息难以利用。
NTM
神经图灵机(NTM)通过利用在外部进行记忆的存储,读取与写入解决了这一问题。神经图灵机 (NTM) 架构包含两个基本组件:神经网络控制器和存储库。像大多数神经网络一样,控制器通过输入和输出向量与外部世界交互。与标准网络不同,它还使用选择性读写操作与内存矩阵进行交互。通过类比图灵机,将参数化这些操作的网络输出称为“头”。
下图是NTM的结构,在每个更新周期中,控制器网络从外部环境接收输入并发出输出作为响应。它还通过一组并行读写“头”读取和写入内存矩阵。虚线表示 NTM 与外界的划分。神经图灵机使所有的读写操作都可微分化,因此可以用神经网络误差后向传播的方式去训练模型。
读取记忆(Reading)
我们把记忆看作是一个
N
×
M
N×M
N×M的矩阵
M
t
M_t
Mt,
t
t
t表示当前时刻, 表示记忆会随着时间发生变化。我们的读过程就是生成一个定位权值向量
w
t
w_t
wt,长度为
N
N
N,表示
N
N
N个位置对应的记忆权值大小,将权重归一化:
∑
i
w
t
(
i
)
=
1
,
0
≤
w
t
(
i
)
≤
1
,
∀
i
.
\sum_iw_t(i)=1,\quad \quad 0\leq w_t(i)\leq 1,\forall i.
i∑wt(i)=1,0≤wt(i)≤1,∀i.
最后读出的记忆向量r_t为:
r
t
←
∑
i
w
t
(
i
)
M
t
(
i
)
r_t\leftarrow \sum_iw_t(i)M_t(i)
rt←i∑wt(i)Mt(i)
可以看到,读取操作实际上就是对
N
N
N条记忆加权求和,相当于attention。
写入记忆(Writing)
NTM的写入部分从LSTM 中的输入和遗忘门中汲取灵感,将每个写入分解为两部分:擦除与添加。NTM通过擦除向量
e
t
e_t
et (erase vector) 和一个增加向量
a
t
a_t
at (add vector),长度都为
N
N
N,向量中每个元素的值大小范围从0到1,表示要增加或者删除的信息。对于写记忆过程,神经图灵机首先执行一个擦除操作,擦除程度的大小同样由向量
w
t
w_t
wt决定:
M
^
t
=
M
t
−
1
(
i
)
(
1
−
w
t
(
i
)
e
t
)
\hat{M}_t=M_{t−1}(i)(1−w_t(i)e_t)
M^t=Mt−1(i)(1−wt(i)et)
而添加过程如下:
M
t
(
i
)
=
M
^
t
(
i
)
+
w
t
(
i
)
a
t
M_t(i)=\hat{M}_t(i)+w_t(i)a_t
Mt(i)=M^t(i)+wt(i)at
其中擦除和添加向量都有
M
M
M个独立成分,从而可以精细控制每个内存位置中的哪些元素被修改。
定位机制(Addressing Mechanisms)
NTM结合了基于内容的(content-based)和基于位置的(location-based)方法,提出一个定位机制用于生成定位向量
w
t
w_t
wt,下图是定位机制的流程图。
按内容聚焦(Focusing by Content)
NTM的相似性度量是余弦相似度:
K
[
u
,
v
]
=
u
⋅
v
∣
∣
u
∣
∣
⋅
∣
∣
u
∣
∣
K[u,v]=\frac{u⋅v}{||u||⋅||u||}
K[u,v]=∣∣u∣∣⋅∣∣u∣∣u⋅v
首先控制器给出一个长度为M的
k
t
kt
kt向量作为查询的key,然后计算
k
t
k_t
kt与
M
t
M_t
Mt中各个记忆向量的相似度
K
[
.
,
.
]
K[.,.]
K[.,.],最后经过Softmax操作得到基于内容的定位向量
w
t
c
w^c_t
wtc:
w
t
c
(
i
)
=
e
x
p
(
β
t
K
[
k
t
,
M
t
(
i
)
]
)
∑
j
e
x
p
(
β
t
K
[
k
t
,
M
t
(
j
)
w_t^c(i)=\frac{exp(\beta_tK[k_t,M_t(i)])}{\sum_jexp(\beta_tK[k_t,M_t(j)}
wtc(i)=∑jexp(βtK[kt,Mt(j)exp(βtK[kt,Mt(i)])
按位置聚焦(Focusing by Location)
基于位置的寻址机制旨在促进内存位置的简单迭代和随机访问跳转,它通过实现权重的旋转移位来实现。如,如果当前权重完全集中在单个位置,旋转 1 会将焦点转移到下一个位置。
- 插值(interpolation):
在旋转之前,每个head都会在 (0, 1) 范围内生成一个标量插值门 g t g_t gt,该值用于混合head在上一个时间步产生的权重 w t − 1 w_{t-1} wt−1和当前步产生的权重 w t c w^c_t wtc,产生门控权重 w t g w^g_t wtg:
w t g ← g t w t c + ( 1 − g t ) w t − 1 w_t^g \leftarrow g_tw_t^c+(1-g_t)w_{t-1} wtg←gtwtc+(1−gt)wt−1
如果门为零,则完全忽略内容权重,并使用前一时间步的权重。相反,如果门是一,则忽略前一次迭代的权重,系统应用基于内容的定位。 - 移位(shift):
在插值后实行移位操作,对于 w t g w_t^g wtg中的每个位置元素 w t g ( i ) w_t^g(i) wtg(i),我们考虑它相邻的k个偏移元素,认为这k个元素与 w t g ( i ) w_t^g(i) wtg(i)相关。例如,如果允许在-1和1之间移动, s t s_t st包含对应于执行 -1、0 和 1 的移动的的3个权值。定义移位权重的最简单方法是使用附加到控制器的适当大小的 softmax 层进行多分类。
除此之外,本文尝试了另外一种方法,在控制器中产生一个缩放因子,该因子为移位位置上均匀分布的下界。比如,如果该缩放因子值为6.7,那么 s t ( 6 ) = 0.3 , s t ( 7 ) = 0.7 , s t s_t(6)=0.3,s_t(7)=0.7,s_t st(6)=0.3,st(7)=0.7,st的其余分量为0(只取整数索引)。
如果索引从0到N-1的N个内存位置,旋转可定义为循环卷积操作:
w ~ t ( i ) ← ∑ j = 0 N − 1 w t g ( j ) s t ( i − j ) \widetilde{w}_t(i) \leftarrow \sum\limits_{j=0}^{N-1}w_t^g(j)s_t(i-j) w t(i)←j=0∑N−1wtg(j)st(i−j)
由于卷积操作会使权值的分布趋于均匀化,这将导致本来集中于单个位置的焦点出现发散现象。为了解决这个问题,还需要对结果进行锐化操作。具体做法是Head产生一个因子 γ t ≥ 1 γ_t≥1 γt≥1,并通过如下操作来进行锐化:
w t ( i ) = w ~ t ( i ) γ t ∑ j w ~ t ( j ) γ t w_t(i)=\frac{\widetilde{w}_t(i)^{\gamma_t}}{\sum_j \widetilde{w}_t(j)^{\gamma_t}} wt(i)=∑jw t(j)γtw t(i)γt
矩阵形式:
S t = [ s t ( 0 ) s t ( N − 1 ) ⋯ s t ( 2 ) s t ( 1 ) s t ( 1 ) s t ( 0 ) s t ( N − 1 ) ⋯ s t ( 2 ) ⋮ s t ( 1 ) s t ( 0 ) ⋱ ⋮ s t ( 3 ) ⋱ ⋱ ⋱ s t ( N − 1 ) s t ( N − 1 ) s t ( N − 2 ) ⋯ s t ( 1 ) s t ( 0 ) ] \textbf{S}_t = \left[ \begin{array}{ccc} s_t(0) &s_t(N-1) & \cdots & s_t(2) & s_t(1)\\ s_t(1) &s_t(0) & s_t(N-1) & \cdots & s_t(2)\\ \vdots &s_t(1) & s_t(0) & \ddots & \vdots\\ s_t(3) &\ddots & \ddots & \ddots & s_t(N-1)\\ s_t(N-1) &s_t(N-2) & \cdots & s_t(1) & s_t(0)\\ \end{array} \right] St=⎣⎢⎢⎢⎢⎢⎡st(0)st(1)⋮st(3)st(N−1)st(N−1)st(0)st(1)⋱st(N−2)⋯st(N−1)st(0)⋱⋯st(2)⋯⋱⋱st(1)st(1)st(2)⋮st(N−1)st(0)⎦⎥⎥⎥⎥⎥⎤
w ~ t = S t w t g \widetilde{\textbf{w}}_t=\textbf{S}_t\textbf{w}^g_t w t=Stwtg
锐化操作示例(参考Neural Turing Machines-NTM系列(一)简述):
假设N=5,当前焦点为1,三个位置-1,0,1对应的权值为0.1,0.8,0.1,
w
t
g
=
[
0.06
0.1
0.65
0.15
0.04
]
\textbf{w}_t^g=\left[ \begin{array}{ccc} 0.06\\ 0.1\\ 0.65\\ 0.15\\ 0.04\\ \end{array} \right]
wtg=⎣⎢⎢⎢⎢⎡0.060.10.650.150.04⎦⎥⎥⎥⎥⎤,因此:
S
t
=
[
s
t
(
0
)
s
t
(
4
)
s
t
(
3
)
s
t
(
2
)
s
t
(
1
)
s
t
(
1
)
s
t
(
0
)
s
t
(
4
)
s
t
(
3
)
s
t
(
2
)
s
t
(
2
)
s
t
(
1
)
s
t
(
0
)
s
t
(
4
)
s
t
(
3
)
s
t
(
3
)
s
t
(
2
)
s
t
(
1
)
s
t
(
0
)
s
t
(
4
)
s
t
(
4
)
s
t
(
3
)
s
t
(
2
)
s
t
(
1
)
s
t
(
0
)
]
=
[
0.1
0
0
0.1
0.8
0.8
0.1
0
0
0.1
0.1
0.8
0.1
0
0
0
0.1
0.8
0.1
0
0
0
0.1
0.8
0.1
]
\textbf{S}_t = \left[ \begin{array}{ccc} s_t(0) &s_t(4) & s_t(3) & s_t(2) & s_t(1)\\ s_t(1) &s_t(0) & s_t(4) & s_t(3) & s_t(2)\\ s_t(2) &s_t(1) & s_t(0) & s_t(4) & s_t(3)\\ s_t(3) &s_t(2) & s_t(1) & s_t(0) & s_t(4)\\ s_t(4) &s_t(3) & s_t(2) & s_t(1) & s_t(0)\\ \end{array} \right]=\left[ \begin{array}{ccc} 0.1&0 & 0 & 0.1 &0.8\\ 0.8 &0.1 & 0 & 0 & 0.1\\ 0.1 &0.8 & 0.1 & 0 & 0\\ 0 &0.1 & 0.8 & 0.1 & 0\\ 0 &0 & 0.1 & 0.8 & 0.1\\ \end{array} \right]
St=⎣⎢⎢⎢⎢⎡st(0)st(1)st(2)st(3)st(4)st(4)st(0)st(1)st(2)st(3)st(3)st(4)st(0)st(1)st(2)st(2)st(3)st(4)st(0)st(1)st(1)st(2)st(3)st(4)st(0)⎦⎥⎥⎥⎥⎤=⎣⎢⎢⎢⎢⎡0.10.80.10000.10.80.10000.10.80.10.1000.10.80.80.1000.1⎦⎥⎥⎥⎥⎤
因此:
w
~
t
=
S
t
w
t
g
=
[
0.1
0
0
0.1
0.8
0.8
0.1
0
0
0.1
0.1
0.8
0.1
0
0
0
0.1
0.8
0.1
0
0
0
0.1
0.8
0.1
]
×
[
0.06
0.1
0.65
0.15
0.04
]
=
[
0.053
0.062
0.151
0.545
0.189
]
\widetilde{\textbf{w}}_t=\textbf{S}_t \textbf{w}_t^g= \left[ \begin{array}{ccc} 0.1&0 & 0 & 0.1 &0.8\\ 0.8 &0.1 & 0 & 0 & 0.1\\ 0.1 &0.8 & 0.1 & 0 & 0\\ 0 &0.1 & 0.8 & 0.1 & 0\\ 0 &0 & 0.1 & 0.8 & 0.1\\ \end{array} \right] \times \left[ \begin{array}{ccc} 0.06\\ 0.1\\ 0.65\\ 0.15\\ 0.04\\ \end{array} \right]=\left[ \begin{array}{ccc} 0.053\\ 0.062\\ 0.151\\ 0.545\\ 0.189\\ \end{array} \right]
w
t=Stwtg=⎣⎢⎢⎢⎢⎡0.10.80.10000.10.80.10000.10.80.10.1000.10.80.80.1000.1⎦⎥⎥⎥⎥⎤×⎣⎢⎢⎢⎢⎡0.060.10.650.150.04⎦⎥⎥⎥⎥⎤=⎣⎢⎢⎢⎢⎡0.0530.0620.1510.5450.189⎦⎥⎥⎥⎥⎤
取
γ
t
=
2
γ_t=2
γt=2,
w
t
=
w
~
t
γ
t
∑
j
w
~
t
(
j
)
γ
t
=
[
0.0078
0.0106
0.0630
0.8201
0.0986
]
\textbf{w}_t=\frac{\widetilde{\textbf{w}}_t^{\gamma_t}}{\sum_j \widetilde{w}_t(j)^{\gamma_t}}=\left[ \begin{array}{ccc} 0.0078\\ 0.0106\\ 0.0630\\ 0.8201\\ 0.0986\\ \end{array} \right]
wt=∑jw
t(j)γtw
tγt=⎣⎢⎢⎢⎢⎡0.00780.01060.06300.82010.0986⎦⎥⎥⎥⎥⎤
最后得到最终的
w
t
w_t
wt用于读取和写入数据。
NTM的实现:
源码
https://github.com/loudinthecloud/pytorch-ntm
# 读过程
class NTMReadHead(NTMHeadBase):
def __init__(self, memory, controller_size):
super(NTMReadHead, self).__init__(memory, controller_size)
# Corresponding to k, β, g, s, γ sizes from the paper
self.read_lengths = [self.M, 1, 1, 3, 1]
self.fc_read = nn.Linear(controller_size, sum(self.read_lengths))
self.reset_parameters()
def create_new_state(self, batch_size):
# The state holds the previous time step address weightings
return torch.zeros(batch_size, self.N)
def reset_parameters(self):
# Initialize the linear layers
nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4)
nn.init.normal_(self.fc_read.bias, std=0.01)
def is_read_head(self):
return True
def forward(self, embeddings, w_prev):
"""NTMReadHead前进函数。
:param embeddings:控制器的输入表示形式。
:param w_prev:上一步状态
"""
o = self.fc_read(embeddings)
k, β, g, s, γ = _split_cols(o, self.read_lengths)
# Read from memory
w = self._address_memory(k, β, g, s, γ, w_prev)
r = self.memory.read(w)
return r, w
# 写过程
class NTMWriteHead(NTMHeadBase):
def __init__(self, memory, controller_size):
super(NTMWriteHead, self).__init__(memory, controller_size)
# Corresponding to k, β, g, s, γ, e, a sizes from the paper
self.write_lengths = [self.M, 1, 1, 3, 1, self.M, self.M]
self.fc_write = nn.Linear(controller_size, sum(self.write_lengths))
self.reset_parameters()
def create_new_state(self, batch_size):
return torch.zeros(batch_size, self.N)
def reset_parameters(self):
# Initialize the linear layers
nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4)
nn.init.normal_(self.fc_write.bias, std=0.01)
def is_read_head(self):
return False
def forward(self, embeddings, w_prev):
"""NTMWriteHead前进函数。
:param embeddings:控制器的输入表示形式。
:param w_prev:上一步状态
"""
o = self.fc_write(embeddings)
k, β, g, s, γ, e, a = _split_cols(o, self.write_lengths)
# e should be in [0, 1]
e = F.sigmoid(e)
# Write to memory
w = self._address_memory(k, β, g, s, γ, w_prev)
self.memory.write(w, e, a)
return w