参考自 GitHub BMINF项目
直接上代码
class PositionBias(Layer):
def __init__(self, num_buckets, num_heads, is_decoder):
self.num_buckets = num_buckets
self.is_decoder = is_decoder
self.num_heads = num_heads
self.embedding = Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow(改编自Mesh TensorFlow)
1. 计算方式大概是 key 序列中,跟当前query位置接近的(左边或者右边),线性增加;
当增加到一定程度时,以log增加,避免偏移过大,也表示 比较远的key就不用从位置这个角度区分太多;
此时有些key会具有相同偏移量,这部分key被称为在一个桶(bucket)里面;
2. 参数说明:bidirectional用于判断是encoder(双向计算)还是decoder(单向计算/左侧计算)
3. 以下注释以 num_buckets=32, max_distance=128 的设置 为背景
"""
relative_buckets = 0
if bidirectional: # 如果是encoder
# 32/2=16
num_buckets //= 2
# 大于0的部分+16,结果某一行大致是 [0,0,0,0,0,16,16,16,16,16...],
# 因为relative_position对应的行是 [-4,-3,-2,-1,0,1,2,3,4,5...];
# 这样做是为了区分 左侧和右侧
relative_buckets += (relative_position > 0).astype(np.int32) * num_buckets
# 取绝对值,上面那一行变为 [4,3,2,1,0,1,2,3,4,5...]
relative_position = np.abs(relative_position)
else: # 如果是decoder
# 上面那一行变为 [4,3,2,1,0,0,0,0,0...],表示decoder只关注左侧(自回归)
relative_position = -np.clip(relative_position, None, 0) # 负的(在左边)不变,正的为0
# 这里设置要线性增加偏移量的范围:is_small用来区分线性和非线性两部分
# 对于encoder,由于是双向的,所以max_exact=16//2=8;
# decoder只考虑左侧,因此max_exact=16
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# 注:以下两行代码和源代码不一样,为了可理解性我做了轻微修改
t1 = relative_position.astype(np.float32)
# 这里就是主要的计算了,t1内的数值小于max_exact时(即当前key位置靠近query位置),中括号代表的值很可能<0,导致整个式子值<0;
# 只要稍微偏离query位置,中括号值就越来越大,直到t1内的值接近max_exact,此时中括号值=0,整体值=max_exact>0;
# key偏离query越远,则整体值将越来越大(>max_exact)
# 这个过程就是由log控制的非线性偏移;
relative_postion_if_large = max_exact + [
np.log(t1/max_exact + 1e-6) / math.log(max_distance/max_exact)
* (num_buckets - max_exact)
].astype(np.int32) # 括号内数值转换为int
# 这里做一些规整,限制最小最大偏移量
relative_postion_if_large = np.clip(relative_postion_if_large, 0, num_buckets - 1)
# 上述非线性偏移同样作用到is_small定义的线性段,因此通过is_small进行过滤非线性偏移量
relative_buckets += np.where(is_small, relative_position, relative_postion_if_large)
return relative_buckets
def forward(self, allocator : Allocator, query_len, key_len):
# 以下3行计算relative_position,即怎么定义query和key之间的相对偏移;
context_position = np.arange(query_len, dtype=np.int32)[:, np.newaxis]
memory_position = np.arange(key_len, dtype=np.int32)[np.newaxis, :]
relative_position = memory_position - context_position # 双向偏移,比如对于keypos=5,与mem的相对距离为 -4,-3,-2,-1,0,1,2,3,4,...(详见下面测试)
# 这里根据relative_position计算了relative_position_bucket;实现见上面函数定义
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets= self.num_buckets,
)
# 以下内容无需关注
out = self.embedding.forward(allocator, relative_position_bucket)
assert out.shape == (query_len, key_len, self.num_heads)
out = out.transpose((2, 1, 0))[cupy.newaxis]
return out # (1, num_heads, key_len, query_len)
举例说明
- 取 query_len=key_len=64
- 结果1:relative_position
- 结果2:第一个图是encoder/decoder两种情形下得到relative_position_bucket;第二个图是两种情况下具体某一行的值示例(注意第二个图第一个输出是decoder情形下的值);