transformer t5 relative position代码解读

先整体看relative_position的调用说明

The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. (如果双向=False,正相对位置是无效的。)

		We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. 
		我们对于小的相对位置使用小的buckets,对于大的相对位置使用大的buckets。
		All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
		所有相对位置大于max_distance的映射到相同的bucket,所有相对位置小于等于max_distance的映射到相同的bucket。
        This should allow for more graceful generalization to longer sequences than the model has been trained on

对于相对位置编码的综合的解释

这个设计的思路其实也很直观,就是比较邻近的位置(0~7),我们需要比较得精细一些,所以给它们都分配一个独立的位置编码,至于稍远的位置(比如8~11),我们不用区分得太清楚,所以它们可以共用一个位置编码,距离越远,共用的范围就可以越大,直到达到指定范围再clip。使用类似于nezha的相对位置编码,考虑i-j的相对位置内容信息。

T5Attention类别的调用

首先查看一下初始化参数的值

self.has_relative_attention_bias = False
self.relative_attention_num_buckets = 32
self.d_model = 512
self.key_value_proj_dim = 64
self.inner_dim = 512

接着进入forward调用程序部分

batch_size,seq_length = hidden_states.shape[:2]
real_seq_length = seq_length

得到参数

batch_size = 1,seq_length = 15,real_seq_length = 15

然后经历三个dense网络层

query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

这里对应的形状参数为

query_states = 
torch.Size([1, 8, 15, 64])
......

接下来本质上是跟nezha模型一样,有个position_bias参数,然后softmax(Q*K+bias)…
这里T5中的selfattention公式与之前的selfattention公式有所区别,T5之中的selfattention公式内容为
s o f t m a x ( Q ∗ K T + p o s i t i o n b i a s ) ∗ V softmax(Q*K^{T}+position_{bias})*V softmax(QKT+positionbias)V
这里不需要除以 d k d_{k} dk
接下来计算scores的内容

scores = torch.matmul(
	query_states,key_states.transpose(3,2)
)

得到的scores的内容是

scores = (1,8,15,15)

接下来计算位置便移的position_bias内容

position_bias = self.compute_bias(real_seq_length,key_length)

进入的到compute_bias函数之中,查看它的对应计算过程

t5模型compute_bias调用过程

(感觉这里t5模型的compute_bias与nezha中的compute_bias有点类似???)
注意!!!t5的相对位置编码和nezha的相对位置编码还是不一样的!!!

def compute_bias(self, query_length, key_length):
    """Compute binned relative position bias"""
    #query_length = 15,key_length = 15
    context_position = torch.arange(
        query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[:, None]
    memory_position = torch.arange(
        key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值