输出和状态是一样的,前一个状态为state,前一个输出也是state,其宽度都是num_units参数
重置门和更新门分别是r和u
首先输入和前一个输出拼接在一起,然后加权(_gate_kernel)再按列平分(因为r,u都是对状态的加权,所以宽度和状态的宽度一样,都是num_units参数),得到重置门r和更新门u
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._gate_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
value = math_ops.sigmoid(gate_inputs)
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
然后对state进行重置(遗忘)
r_state = r * state
遗忘之后的状态和输入拼接在一起,加权(_candidate_kernel)得到候选状态candidate,接着激活
candidate = math_ops.matmul(
array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
candidate = nn_ops.bias_add(candidate, self._candidate_bias)
c = self._activation(candidate)
最后对原状态和候选状态各取一定比例叠加在一起,得到新状态,和新输出
new_h = u * state + (1 - u) * c
return new_h, new_h