本篇文章讲讲LLaMA的结构,已经有很多文章已经对LLaMA在一些结构上任务表现上做了一些解析,本文主要从优化的角度、实现kernel的角度解析一下LLaMA,读者事先对transformer的结构有基本认识最好。本文首发于我的公众号“AI不止算法”,文章链接在此
LLaMA简单介绍
几个月前,FB开源了LLAMA,LLAMA1包括三个参数量的模型7B、13B、65B, 证明了完全可以通过公开数据集来训练最先进的模型,而无需使用专有和不可获取的数据集,同时LLaMA-13B 在大多数benchmark优于 GPT-3,尽管大小只有后者的1/10。在更大规模上,LLaMA-65B 参数模型也与可以与Chinchilla或PaLM-540B相竞争,这是之前bloom、OPT等没有做到的。本文不谈LLaMA的预训练数据多么多么怎么样,也不谈LLaMA在各个任务上的表现如何,重点从性能优化的角度谈谈LLaMA的模型结构。
模型结构
LLaMA主体结构依然是transformer组成,和其它LLM不同的是:
- 使用RMSNorm(即Root Mean square Layer Normalization)对每个Transformer子层的input进行Pre Norm
- 使用激活函数SwiGLU
- 使用RoPE进行相对位置编码
- 使用了AdamW优化器,并使用cosine learning rate schedule (AdamW和Adam的区别我不是特别清楚,先放着不讲)
RMSNorm为layerNorm的变体,在分子分母都省去了Mean,同时少了beta参数,虽然不用再计算variance了,但我觉得Welford依然是Normlization类算子性能的最优解
# RMSNorm
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps # ε
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# RMSNorm
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm