tensorflow中关于BahdanauAttention以及LuongAttention实现细节

背景介绍

在 TensorFlow 中,Attention 的相关实现代码是在 tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 文件中,这里面实现了两种 Attention 机制,分别是 BahdanauAttention 和 LuongAttention,其实现论文分别如下:

Neural Machine Translation by Jointly Learning to Align and Translate,Bahdanau, et al
Effective Approaches to Attention-based Neural Machine Translation, Luong, et al

整个 attention_wrapper.py 文件中主要包含几个类,我们主要关注其中几个:

  • AttentionMechanism、_BaseAttentionMechanism、LuongAttention、BahdanauAttention
    实现了 Attention 机制的逻辑
  • AttentionMechanism 是 Attention 类的父类,继承了 object 类,内部没有任何实现。
  • _BaseAttentionMechanism 继承自 AttentionMechanism 类,定义了 Attention 机制的一些公共方法实现和属性。
  • LuongAttention、BahdanauAttention 均继承 _BaseAttentionMechanism
    类,分别实现了上面两篇论文的 Attention 机制。 AttentionWrapperState 用来存储整个计算过程中的
    state,和 RNN 中的 state 类似,只不过这里额外还存储了 attention、time 等信息。
  • AttentionWrapper 主要用于对封装 RNNCell,继承自 RNNCell,封装后依然是 RNNCell
    的实例,可以构建一个带有 Attention 机制的 Decoder。

另外还有一些公共方法,例如 hardmax、safe_cumpord 等。
下面我们以 BahdanauAttention 为例来说明 Attention 机制及 AttentionWrapper 的实现。

1.BahdanauAttention介绍

BahdanauAttention类,首先看__init__函数:

	def __init__(self,
	 	num_units,
	 	memory,
	 	memory_sequence_length=None,
	 	normalize=False,
	 	probability_fn=None,
	 	score_mask_value=None,
	 	dtype=None,
	 	name="BahdanauAttention"):
  • num_units:神经元节点数,我们知道在计算 eij 的时候,需要使用 si−1 和 hj 来进行计算,而二者的维度可能并不是统一的,需要进行变换和统一,所以这里就有了 Wa 和 Ua 这两个系数,所以在代码中就是用 num_units 来声明了一个全连接 Dense 网络,用于统一二者的维度,以便于下一步的计算:

     query_layer=layers_core.Dense(
      	num_units, name="query_layer", use_bias=False, dtype=dtype)
     memory_layer=layers_core.Dense(
      	num_units, name="memory_layer", use_bias=False, dtype=dtype)
    
  • memory:The memory to query,一般为RNN encoder的输出。维度为[batch_size, max_time, context_dim]。在父类_BaseAttentionMechanism的初始化方法中,

     with ops.name_scope(
      	name, "BaseAttentionMechanismInit", nest.flatten(memory)):
      	self._values = _prepare_memory(
      	memory, memory_sequence_length,
      	check_inner_dims_defined=check_inner_dims_defined)
      	self._keys = (
      	self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable
      	else self._values)
    

首先是使用_prepare_memory函数对memory进行处理,然后使用上面定义的memory_layer对memory进行全连接的维度变换,变换成[batch_size, max_time, num_units]

  • memory_sequence_length:Sequence lengths for the batch entries in memory. 即 memory 变量的长度信息,类似于 dynamic_rnn 中的 sequence_length,被 _prepare_memory() 方法调用处理 memory 变量,进行 mask 操作:

     seq_len_mask = array_ops.sequence_mask(
         memory_sequence_length,
         maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
         dtype=nest.flatten(memory)[0].dtype)
     seq_len_batch_size = (
         memory_sequence_length.shape[0].value
         or array_ops.shape(memory_sequence_length)[0])
    
  • normalize:Whether to normalize the energy term. 即是否要实现标准化,方法出自论文:Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, Salimans, et al。

  • probability_fn:A callable function which converts the score to probabilities. 计算概率时的函数,必须是一个可调用的函数,默认使用 softmax(),还可以指定 hardmax() 等函数。

  • score_mask_value:The mask value for score before passing into probability_fn. The default is -inf. Only used if memory_sequence_length is not None. 在使用 probability_fn 计算概率之前,对 score 预先进行 mask 使用的值,默认是负无穷。但这个只有在 memory_sequence_length 参数定义的时候有效。

  • dtype:The data type for the query and memory layers of the attention mechanism. 数据类型,默认是 float32。

  • name:Name to use when creating ops,自定义名称。

然后看__call__()函数:

def __call__(self, query, state):
       with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
			processed_query = self.query_layer(query) if self.query_layer else query
			score = _bahdanau_score(processed_query, self._keys, self._normalize)
			alignments = self._probability_fn(score, state)
			next_state = alignments
			return alignments, next_state

call函数首先对query进行全连接层的维度变换,然后调用_bahdanau_score函数计算score,也就是eij,然后调用_probability_fn函数计算softmax.

  • 在_bahdanau_score函数中,_key函数表示Encoder的输出,也即是memory的变换后的值。procesed_query值为decoder 隐藏层。_bahdanau_score函数部分代码如下所示:

     if normalize:
            # Scalar used in weight normalization
            g = variable_scope.get_variable(
            "attention_g", dtype=dtype,
             initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))),
                                                                            shape=())
             # Bias added prior to the nonlinearity
             b = variable_scope.get_variable("attention_b", [num_units], dtype=dtype,
                                                                  initializer=init_ops.zeros_initializer())
             # normed_v = g * v / ||v||
              normed_v = g * v * math_ops.rsqrt(math_ops.reduce_sum(math_ops.square(v)))
              return math_ops.reduce_sum(normed_v * math_ops.tanh(keys + processed_query + b), [2])
     else:
              return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
    

从代码中可以看出,_bahdanau_score函数主要有两个作用,一个是计算eij,另一个是对eij进行weighted normalization处理。
这里写图片描述

score计算的方式有点类似concat的方式。

  • _probability_fn函数如果不直接指定的话,默认的值为softmax函数。

2.LuongAttention介绍

与BahdanauAttention相比,LuongAttention在具体实现上相似,只是在代码细节上略有不同。下面进行详细的介绍:

  • 首先,在__init__函数中,只是简单的定义了memory_layer,代码如下所下所示:

    	 super(LuongAttention, self).__init__(
    			query_layer=None,
    			memory_layer=layers_core.Dense(
    			num_units, name="memory_layer", use_bias=False, dtype=dtype),
    			memory=memory,
    			probability_fn=wrapped_probability_fn,
    			memory_sequence_length=memory_sequence_length,
    			score_mask_value=score_mask_value,
    			name=name)
    
  • 其次,在__call__函数中,结构相似,主要区别是将socre函数变成了_luong_score函数。

  • 最后,在_luong_score函数中,主要代码如下:

     score = math_ops.matmul(query, keys, transpose_b=True)
     score = array_ops.squeeze(score, [1])
     
     if scale:
             # Scalar used in weight scaling
             g = variable_scope.get_variable(
                   "attention_g", dtype=dtype,
             initializer=init_ops.ones_initializer, shape=())
             score = g * score
    

这里实现的是简单的相乘的方式。不过需要注意的一点是,在attention的父类_BaseAttentionMechanism中,已经对self._values值进行dense处理,处理后的结果就是key。

相关链接:https://cuiqingcai.com/5873.html

要在TensorFlow实现Cross-Attention模块,您可以按照以下步骤进行操作: 1. 导入所需的库和模块: ```python import tensorflow as tf from tensorflow.keras.layers import Layer, Dense ``` 2. 创建一个自定义的CrossAttention层: ```python class CrossAttention(Layer): def __init__(self, units): super(CrossAttention, self).__init__() self.units = units def build(self, input_shape): self.W1 = self.add_weight(shape=(input_shape[0][-1], self.units), initializer='random_normal', trainable=True) self.W2 = self.add_weight(shape=(input_shape[1][-1], self.units), initializer='random_normal', trainable=True) self.b = self.add_weight(shape=(self.units,), initializer='zeros', trainable=True) super(CrossAttention, self).build(input_shape) def call(self, inputs): query, value = inputs q = tf.matmul(query, self.W1) # Query的线性变换 k = tf.matmul(value, self.W2) # Value的线性变换 scores = tf.matmul(q, tf.transpose(k, [0, 2, 1])) # 计算注意力分数 attention_weights = tf.nn.softmax(scores) # 对注意力分数进行softmax归一化 output = tf.matmul(attention_weights, value) + self.b # 加权求和 return output ``` 3. 使用CrossAttention层: ```python # 创建模型 input_query = tf.keras.Input(shape=(query_len, input_dim)) input_value = tf.keras.Input(shape=(value_len, input_dim)) cross_attention = CrossAttention(units=hidden_dim) output = cross_attention([input_query, input_value]) model = tf.keras.Model(inputs=[input_query, input_value], outputs=output) ``` 在上述代码,我们首先定义了一个自定义的CrossAttention层,其build()函数用于创建权重。然后,在call()函数,我们按照Cross-Attention的计算公式进行操作:通过线性变换获得Query和Value的表示,计算注意力分数,使用softmax归一化注意力分数,最后对Value进行加权求和。最后,我们使用这个CrossAttention层构建了一个模型,并将输入数据传递给该模型以获取输出。 请注意,上述代码仅为示例,您可能需要根据自己的具体需求进行修改和调整。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一夜了

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值