relative position encoding快速看懂相对位置编码代码实现

先理解相对位置:
在这里插入图片描述
在这里插入图片描述
举例如果有5个token,相对位置就有9种情况。

def _generate_relative_positions_matrix(length, max_relative_position,
                                        cache=False):
  """Generates matrix of relative positions between inputs."""
  if not cache:
    range_vec = tf.range(length)
    range_mat = tf.reshape(tf.tile(range_vec, [length]), [length, length])
    distance_mat = range_mat - tf.transpose(range_mat)
  else:
    distance_mat = tf.expand_dims(tf.range(-length+1, 1, 1), 0)
  distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                          max_relative_position)
  # Shift values to be >= 0. Each integer still uniquely identifies a relative
  # position difference.
  final_mat = distance_mat_clipped + max_relative_position
  return final_mat  # 解释前面是如何得到相对位置的

结合具体例子快速理解前面代码:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

然后是通过前面的相对位置矩阵,得到相对位置编码

def _generate_relative_positions_embeddings(length, depth,
                                            max_relative_position, name,
                                            cache=False):
  """Generates tensor of size [1 if cache else length, length, depth]."""
  with tf.variable_scope(name):
    relative_positions_matrix = _generate_relative_positions_matrix(
        length, max_relative_position, cache=cache)
    vocab_size = max_relative_position * 2 + 1
    # Generates embedding for each relative position of dimension depth.
    embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
    embeddings = tf.gather(embeddings_table, relative_positions_matrix)
    #tf.gather函数是TensorFlow中的一个操作,用于从一个张量中提取指定索引位置的元素。
    return embeddings

大模型解释教学,感觉讲地很清楚:
在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值