鹅厂面试官:Transformer为何需要位置编码?

本文基于 llama 模型的源码,学习相对位置编码的实现方法,本文不细究绝对位置编码和相对位置编码的数学原理。

大模型新人在学习中容易困惑的几个问题:

  • 为什么一定要在 transformer 中使用位置编码?

  • 相对位置编码在 llama 中是怎么实现的?

  • 大模型的超长文本预测和位置编码有什么关系?

01

为什么需要位置编码

很多初学者都会读到这样一句话:transformer 使用位置编码的原因是它不具备位置信息。大家都只把这句话当作公理,却很少思考这句话到底是什么意思?

这句话的意思是,如果没有位置编码,那么 “床前明月”、“前床明月”、“前明床月” 这几个输入,会预测出完全一样的文本。

也就是说,不管你输入的 prompt 顺序是什么,只要 prompt 的文本是相同的,那么模型 decode 的文本就只取决于 prompt 的最后一个 token。

import torch``from torch import nn``import math``   ``   ``batch = 1``dim = 10``num_head = 2``embedding = nn.Embedding(5, dim)``q_matrix = nn.Linear(dim, dim, bias=False)``k_matrix = nn.Linear(dim, dim, bias=False)``v_matrix = nn.Linear(dim, dim, bias=False)``   ``   ``x = embedding(torch.tensor([1,2,3])).unsqueeze(0)``y = embedding(torch.tensor([2,1,3])).unsqueeze(0)``   ``   ``def attention(input):`    `q = q_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)`    `k = k_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)`    `v = v_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)``   ``   `    `attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim // num_head)`    `attn_weights = nn.functional.softmax(attn_weights, dim=-1)`    `outputs = torch.matmul(attn_weights, v).transpose(1, 2).reshape(1, len([1,2,3]), dim)`    `print(outputs)``   ``   ``attention(x)``attention(y)``   

执行上面的代码会发现,虽然 x 和 y 交换了第一个 token 和第二个 token 的输入顺序,但是第三个 token 的计算结果完全没有发生改变,那么模型预测第四个 token 时,便会得到相同的结果。

如果有读者对矩阵运算感到混淆的话,可以看看下面的简单推导:

在这里插入图片描述

可以看出,当第一个 token 与第二个 token 交换顺序后,模型输出矩阵的第一维和第二维也交换了顺序,但输出的值完全没有变化。

第三个 token 的输出结果也是完全没有受到影响,这也就是前面说的:如果没有位置编码,模型 decode 的文本就只取决于 prompt 的最后一个 token

不过需要注意的是,由于 attention_mask 的存在(前置位 token 看不到后置位 token),所以即使不加位置编码,transformer 的输出还是会受到 token 的位置影响。

02

相对位置编码的实现

我们以 modeling_llama.py 的源码为例,来学习相对位置编码的实现方法。

class LlamaRotaryEmbedding(torch.nn.Module):`    `def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):`        `super().__init__()`        `inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))`        `self.register_buffer("inv_freq", inv_freq)``   ``   `        ``# Build here to make `torch.jit.trace` work.``        `self.max_seq_len_cached = max_position_embeddings`        `t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)`        `freqs = torch.einsum("i,j->ij", t, self.inv_freq)`        `# Different from paper, but it uses a different permutation in order to obtain the same calculation`        `emb = torch.cat((freqs, freqs), dim=-1)`        `self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)`        `self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)``   ``   `    `def forward(self, x, seq_len=None):`        `# x: [bs, num_attention_heads, seq_len, head_size]`        ``# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.``        `if seq_len > self.max_seq_len_cached:`            `self.max_seq_len_cached = seq_len`            `t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)`            `freqs = torch.einsum("i,j->ij", t, self.inv_freq)`            `# Different from paper, but it uses a different permutation in order to obtain the same calculation`            `emb = torch.cat((freqs, freqs), dim=-1).to(x.device)`            `self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)`            `self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)`        `return (`            `self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),`            `self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),`        `)``   ``   ``def rotate_half(x):`    `"""Rotates half the hidden dims of the input."""`    `x1 = x[..., : x.shape[-1] // 2]`    `x2 = x[..., x.shape[-1] // 2 :]`    `return torch.cat((-x2, x1), dim=-1)``   ``   ``def apply_rotary_pos_emb(q, k, cos, sin, position_ids):`    ``# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.``    `cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]`    `sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]`    `cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]`    `sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]`    `q_embed = (q * cos) + (rotate_half(q) * sin)`    `k_embed = (k * cos) + (rotate_half(k) * sin)`    `return q_embed, k_embed

相对位置编码在 attention 中的应用方法如下:

self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)``cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)``   ``   ``query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)``   ``   ``if past_key_value is not None:`    `# reuse k, v, self_attention`    `key_states = torch.cat([past_key_value[0], key_states], dim=1)`    `value_states = torch.cat([past_key_value[1], value_states], dim=1)``   

根据 value_states 矩阵的形状去调取 cos 和 sin 两个 tensor, cos 与 sin 的维度均是 batch_size * head_num * seq_len * head_dim;

利用 apply_rotary_pos_emb 去修改 query_states 和 key_states 两个 tensor,得到新的 q,k 矩阵

需要注意的是,在解码时,position_ids 的长度是和输入 token 的长度保持一致的,prompt 是 4 个 token 的话。

第一次解码时,position_ids: tensor([[0, 1, 2, 3]], device=‘cuda:0’),q 矩阵与 k 矩阵的相对位置编码信息通过 apply_rotary_pos_emb() 获得;

第二次解码时,position_ids: tensor([[4]], device=‘cuda:0’),当前 token 的相对位置编码信息通过 apply_rotary_pos_emb() 获得。

前 4 个 token 的相对位置编码信息则是通过 key_states = torch.cat([past_key_value[0], key_states], dim=1) 集成到 k 矩阵中;

……

……

以上代码的公式,均可以从苏神原文中找到。

这些代码可以从 llama 模型中剥离出来直接执行,如果感到困惑,可以像下面一样,将 apply_rotary_pos_emb() 的整个过程给 print 出来观察一下:

head_num, head_dim, kv_seq_len = 8, 20, 5``position_ids = torch.tensor([[0, 1, 2, 3, 4]])``query_states = torch.randn(1, head_dim, kv_seq_len, head_dim)``key_states = torch.randn(1, head_dim, kv_seq_len, head_dim)``value_states = torch.randn(1, head_dim, kv_seq_len, head_dim)``rotary_emb = LlamaRotaryEmbedding(head_dim)``cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)``print(cos, sin)``query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

03

位置编码与长度外推

长度外推指的是,大模型在训练的只见过长度为 X 的文本,但在实际应用时却有如下情况:

我们假设 X 的取值为 4096,那么也就意味着,模型自始至终没有见到过 pos_id >= 4096 的位置编码,进而导致模型的预测结果完全不可控。

因此,解决长度外推问题的关键便是如何让模型见到比训练文本更长的位置编码。

在这里插入图片描述

零基础如何学习大模型 AI

领取方式在文末

为什么要学习大模型?

学习大模型课程的重要性在于它能够极大地促进个人在人工智能领域的专业发展。大模型技术,如自然语言处理和图像识别,正在推动着人工智能的新发展阶段。通过学习大模型课程,可以掌握设计和实现基于大模型的应用系统所需的基本原理和技术,从而提升自己在数据处理、分析和决策制定方面的能力。此外,大模型技术在多个行业中的应用日益增加,掌握这一技术将有助于提高就业竞争力,并为未来的创新创业提供坚实的基础。

大模型典型应用场景

AI+教育:智能教学助手和自动评分系统使个性化教育成为可能。通过AI分析学生的学习数据,提供量身定制的学习方案,提高学习效果。
AI+医疗:智能诊断系统和个性化医疗方案让医疗服务更加精准高效。AI可以分析医学影像,辅助医生进行早期诊断,同时根据患者数据制定个性化治疗方案。
AI+金融:智能投顾和风险管理系统帮助投资者做出更明智的决策,并实时监控金融市场,识别潜在风险。
AI+制造:智能制造和自动化工厂提高了生产效率和质量。通过AI技术,工厂可以实现设备预测性维护,减少停机时间。

AI+零售:智能推荐系统和库存管理优化了用户体验和运营成本。AI可以分析用户行为,提供个性化商品推荐,同时优化库存,减少浪费。

AI+交通:自动驾驶和智能交通管理提升了交通安全和效率。AI技术可以实现车辆自动驾驶,并优化交通信号控制,减少拥堵。


这些案例表明,学习大模型课程不仅能够提升个人技能,还能为企业带来实际效益,推动行业创新发展。

学习资料领取

如果你对大模型感兴趣,可以看看我整合并且整理成了一份AI大模型资料包,需要的小伙伴文末免费领取哦,无偿分享!!!
vx扫描下方二维码即可
加上后会一个个给大家发

在这里插入图片描述

部分资料展示

一、 AI大模型学习路线图

整个学习分为7个阶段
在这里插入图片描述

二、AI大模型实战案例

涵盖AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,皆可用。
在这里插入图片描述

三、视频和书籍PDF合集

从入门到进阶这里都有,跟着老师学习事半功倍。
在这里插入图片描述

在这里插入图片描述

四、LLM面试题

在这里插入图片描述

如果二维码失效,可以点击下方链接,一样的哦
【CSDN大礼包】最新AI大模型资源包,这里全都有!无偿分享!!!

😝朋友们如果有需要的话,可以V扫描下方二维码联系领取~
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值