简介
最近的一个工作需要用到MIMNCell,但是原本的论文其实一篇比较工业化的论文,里面对于离线部分的MIMN可以说完全没有解释,我一步一步的将官方的实现在这里做一些分享。
官方paper在这里:
传送门
MIMNCell的主要模块
首先MIMNCell我认为是 RNN结构的一种改进。当然其中增加了很多的模块,但是输入仍然是一个时序序列。
其中主要包括:
- Controller
- Memory Read
- Memory Write
- MIU部分,其中重要的是理解MIU部分维护的S矩阵
基本的工作流程是:
controller
在这里我们假设 MIMNCell的输入是 x;以及上一个MIMNCell的各种状态,如果是第一次输入,也就是t=0的情况下,就用0初始化。
首先根据将 x和上一个状态的Memory Read的输出 read_vector 输入到Controller中,输入的时候需要将x和read_vector 拼接起来
而Controller就是一个标准的GRU。
然后对于Controller这个GRU来说,inputs = x和read_vector的concat,初始状态就是上一个controller上一个时间步的state。
这样我们就得到了一个controller_output & controller_state.
根据一个fully_connect layer,输入是controller_output, 输出是一个非常大的维度(包括head_parameter & erase_add parameters):
- 前num_parameters_per_head * num_heads是head的参数向量
其中num_parameters_per_head = memory_vector_dim(超参)
其中num_head = read_head_num + write_head_num
- 后self.memory_vector_dim * 2 * self.write_head_num 表示的是erase parameter 和 add parameter
这些paramerter都是根据controller_output生成,然后用于下面的处理。
NTM 部分(memory read & write)
NTM维护着一个M矩阵,read和write也是用于更新和修改这个M矩阵的。
对于read和write部分,每一次输入第t个behavior vector,都会生成一个paper中叫weight vector,这里无论是read head还是write head都是一样的weight vector的获取方式,论文中的部分我直接复制在这里:
但是这里首先
k
t
k_t
kt的生成论文中并没有说明,
在coding中
k
t
k_t
kt的获取是利用每个head parameter中的memory_vector_dim经过一个tanh激活函数得到。然后经过上图的计算就可以为每个head(read & write)都获得一个w向量。
注:
在后面的操作中,memory read的部分是利用read head,然后read head利用的是相对应的w向量
memory read的部分是利用read write,然后write head利用的是相对应的w向量
对每个read head
每一个read head都会输出一个read vector,
这里需要注意的是此时的M还是没有更新过的,也就是t-1状态的M。
到这里为止,memory read就完成了他的 工作。
现在其实我们是在介绍NTM的部分,还没有介绍memory write是如何更新,但是现在Memory read已经获取到了输出,所以现在就到了MIU的部分。memory write对于M的更新是在MIU更新之后。
MIU部分
翻译过来就是memory 归纳单元。
这里我认为是paper和coding的实现差别有点大的地方。
- 首先,根据第一个read head的w向量,找到M中权重最大的slot。这里假设是index=0的slot,然后进行一个one-hot编码,获得一个mask。
mask = [1, 0, 0, …],当然别忘了要考虑batchsize,这里只是简单举一个例子。 - 另外维护一个channel_rnn,就是一个orignal GRU结构。
向channel_rnn输入 concat([x, M*mask]),初始化state就是t-1 step的channel_rnn state。
这里注意M是t-1 step的M,因为在t step,还没有利用memory write对M进行更新。
然后更新S
S = channel_rnn_state * mask + channel_rnn_prev_state * (1-mask)
这里可以理解为用w中权重最大的index对应的channel_rnn的state来更新t-1 step的旧state,然后就获得了S。
同时获得t step的channel_rnn的输出:
ouput_t. = channel_rnn_output * mask + channel_rnn_prev_ouput * (1-mask)
上面这个过程就是论文中的:
回到NTM中的Memory write
论文中的memory write的更新如下所示:
其中
e
t
e_t
et 和
a
t
a_t
at 就是上述erase parameter 和 add parameter,
w
t
w
w_t^w
wtw表示的就是write w向量,不同的write head对应不同的w。
然后就可以对M进行更新。
总结
这样一个完整的t step的MIMNCell的更新就是这样完成的。
可能我的逻辑有一些不好理解,如果有问题和指正欢迎大家直接留言。
大家共勉~~