文章目录
1. Llama3 整体结构
llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。
注 1 注_1 注1: 下一节中会参照上图中 红色圆形序号 讲解各模块。
注 2 注_2 注2: llama3的RoPE算法被拆成了3个方法来实现,上图中的模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。
2. 模块详解
2.1 模块1: Embeddings
llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。
2.2 模块2: RoPE
前文中已经提及llama3的RoPE算法被拆成了3个方法来实现,模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。本小节具体按照RoPE的原始论文来讲解,主要阐述RoPE的算法原理。
2.2.1 从一个2维的例子说起 RoPE
我们知道,寻找位置编码的基本思路是 输入位置编码经过特征提取的核心算法后的值,应能反应出两个位置之间的先后顺序(这点不是必要的)和相对位置信息。(《Transformer(二)–论文理解:transformer 结构详解》 2.1节 中有简单说明),RoPE的原始论文中给出了一个数学表达,如下式:
<
f
q
(
x
m
,
m
)
,
f
k
(
x
n
,
n
)
>
=
g
(
x
m
,
x
n
,
m
−
n
)
(2.1)
<f_q(x_m,m),f_k(x_n,n)>=g(x_m,x_n,m-n) \tag{2.1}
<fq(xm,m),fk(xn,n)>=g(xm,xn,m−n)(2.1)
f
q
(
x
m
,
m
)
f_q(x_m,m)
fq(xm,m)和
f
k
(
x
n
,
n
)
f_k(x_n,n)
fk(xn,n)分别为query和key。关于
g
(
x
m
,
x
n
,
m
−
n
)
g(x_m,x_n,m-n)
g(xm,xn,m−n),我理解为输入位置变量的计算函数,和我们使用特征抽取器相关,在transformer架构里,我们一般采用点积计算attention score(见公式2.7),所以,
g
(
x
m
,
x
n
,
m
−
n
)
g(x_m,x_n,m-n)
g(xm,xn,m−n)的计算实质上应该还是计算点积(公式左边就是点积,我这里只是再啰嗦的说下为什么时点积形式)。这个函数的参数有三个,
x
m
,
x
n
x_m,x_n
xm,xn是词向量,还有一个是
m
−
n
m-n
m−n,这里之所以是
m
−
n
m-n
m−n而不是
m
m
m和
n
n
n,是因为我们的特征抽取函数(点积注意力)已经做为已知条件固定了,所以我们要在数据进行特征抽取函数前进行变换。
我们的目的就是找到一个这样的变换函数
f
{
q
,
k
}
f_{\{q,k\}}
f{q,k}能表达
f
q
(
x
m
,
m
)
f_q(x_m,m)
fq(xm,m)与
f
k
(
x
n
,
n
)
f_k(x_n,n)
fk(xn,n),使
f
q
f_q
fq与
f
k
f_k
fk做点积操作后能保留
m
−
n
m-n
m−n的信息。当然我们找到了,见公式2.2
RoPE的论文中是先从2D情况下举例说明我们找到的 f ( x ) f(x) f(x)的,如下,当 d = 2 d=2 d=2时:
f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ ] (2.2) f_q(x_m,m) = (\pmb{W}_{q}x_m)e^{im\theta} \\ f_k(x_n,n) = (\pmb{W}_{k}x_n)e^{in\theta} \\ g(x_m,x_n,m-n) = Re[(\pmb{W_q}x_m)(\pmb{W}_kx_n)^{*}e^{i(m-n)\theta}] \tag{2.2} fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ](2.2)
其中
R
e
[
⋅
]
Re[ \cdot ]
Re[⋅]是复数的实部,
(
W
k
x
n
)
∗
(\pmb{W}_{k}x_n)^{*}
(Wkxn)∗表示
(
W
n
)
(\pmb{W}_n)
(Wn)的共轭复数。
θ
∈
R
\theta \in \mathbb{R}
θ∈R 是一个预设的非零常数。我们可以进一步将
f
{
q
,
k
}
f_{\{q,k\}}
f{q,k}写成乘法矩阵:
f
{
q
,
k
}
(
x
m
,
m
)
=
(
c
o
s
m
θ
−
s
i
n
m
θ
s
i
n
m
θ
c
o
s
m
θ
)
(
W
{
q
,
k
}
(
11
)
W
{
q
,
k
}
(
12
)
W
{
q
,
k
}
(
21
)
W
{
q
,
k
}
(
22
)
)
(
x
m
(
1
)
x
m
(
2
)
)
(2.3)
f_{\{q,k\}}(x_m,m)= \left( \begin{matrix} cos\ m\theta & -sin\ m\theta \\ sin\ m\theta & cos\ m\theta \\ \end{matrix} \right) \left( \begin{matrix} W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\ W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}} \\ \end{matrix} \right) \left( \begin{matrix} x^{(1)}_{m} \\ x^{(2)}_{m} \end{matrix} \right) \tag{2.3}
f{q,k}(xm,m)=(cos mθsin mθ−sin mθcos mθ)(W{q,k}(11)W{q,k}(21)W{q,k}(12)W{q,k}(22))(xm(1)xm(2))(2.3)
其中, ( x m ( 1 ) , x m ( 2 ) ) (x^{(1)}_{m},x^{(2)}_{m}) (xm(1),xm(2))是 x m x_m xm在二维坐标系中的表示。同样的, g g g也可以看作一个矩阵,因此可以在2维情况下求解公式(2.1)。
2.2.2 RoPE的一般形式
为了将我们在2D中的结果推广到任意的
x
i
∈
R
d
x_i \in \mathbb{R}^d
xi∈Rd,我们将d维空间划分为d/2个子空间,并根据内积的线性性质将它们组合起来,将
f
{
q
,
k
}
(
x
m
,
n
)
f_{\{q,k\}}(x_m,n)
f{q,k}(xm,n)转化为:
f
{
q
,
k
}
(
x
m
,
m
)
=
R
Θ
,
m
d
W
{
q
,
k
}
x
m
(2.4)
f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m \tag{2.4}
f{q,k}(xm,m)=RΘ,mdW{q,k}xm(2.4)
其中,
W
{
q
,
m
}
\pmb{W}_{\{q,m\}}
W{q,m} 表示与query和key 所对应的转换矩阵 ,
x
m
x_m
xm 为输入向量,
R
Θ
,
m
d
\pmb{R}^d_{\Theta,m}
RΘ,md为旋转矩阵,具体如下:
R
Θ
,
m
d
=
(
c
o
s
m
θ
1
−
s
i
n
m
θ
1
0
0
⋯
0
0
s
i
n
m
θ
1
c
o
s
m
θ
1
0
0
⋯
0
0
0
0
c
o
s
m
θ
2
−
s
i
n
m
θ
2
⋯
0
0
0
0
s
i
n
m
θ
2
c
o
s
m
θ
2
⋯
0
0
⋮
⋮
⋮
⋮
⋱
⋮
⋮
0
0
0
0
⋯
c
o
s
m
θ
d
/
2
−
s
i
n
m
θ
d
/
2
0
0
0
0
⋯
s
i
n
m
θ
d
/
2
c
o
s
m
θ
d
/
2
)
(2.5)
\pmb{R}^{d}_{\Theta,m}= \left( \begin{matrix} cos\ m\theta_1 & -sin\ m\theta_1 &0 &0 & \cdots &0 &0 \\ sin\ m\theta_1 & cos\ m\theta_1 &0 &0 & \cdots &0 &0 \\ 0 & 0 & cos\ m\theta_2 & -sin\ m\theta_2 & \cdots &0 &0 \\ 0 & 0 & sin\ m\theta_2 & cos\ m\theta_2 & \cdots &0 &0 \\ \vdots & \vdots &\vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 &0 &0 & \cdots & cos\ m\theta_{d/2} & -sin\ m\theta_{d/2} \\ 0 & 0 &0 &0 & \cdots & sin\ m\theta_{d/2} & cos\ m\theta_{d/2} \\ \end{matrix} \right) \tag{2.5}
RΘ,md=
cos mθ1sin mθ100⋮00−sin mθ1cos mθ100⋮0000cos mθ2sin mθ2⋮0000−sin mθ2cos mθ2⋮00⋯⋯⋯⋯⋱⋯⋯0000⋮cos mθd/2sin mθd/20000⋮−sin mθd/2cos mθd/2
(2.5)
Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } (2.6) \Theta=\{ \theta_i = 10000^{-2(i-1)/d}, i \in [1,2,...,d/2] \} \tag{2.6} Θ={θi=10000−2(i−1)/d,i∈[1,2,...,d/2]}(2.6)
2.2.3 RoPE的理解
这里我们把我们求出的
f
{
q
,
k
}
(
x
m
,
m
)
=
R
Θ
,
m
d
W
{
q
,
k
}
x
m
f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m
f{q,k}(xm,m)=RΘ,mdW{q,k}xm代入attention score的计算公式
a
m
,
n
=
exp
(
q
m
T
k
n
d
)
∑
j
=
1
N
exp
(
q
m
T
k
j
d
)
(2.7)
a_{m,n}=\frac{\exp{(\frac{q^{T}_mk_n}{\sqrt{d}})}}{\sum^N_{j=1}{\exp{(\frac{q^{T}_mk_j}{\sqrt{d}})}}} \tag{2.7}
am,n=∑j=1Nexp(dqmTkj)exp(dqmTkn)(2.7)
这里我们只需要看 q m T k m q^T_{m}k_m qmTkm即可,公式的其余部分不会改变结果形式。把公式2.4代入2.7
q m T k n = ( R Θ , m d W q x m ) T ( R Θ , n d W k x n ) = x T W q R Θ , n − m d W k x n (2.8) q^{T}_{m}k_n=(\pmb{R}^d_{\Theta,m}\pmb{W}_qx_m)^T(\pmb{R}^d_{\Theta,n}\pmb{W}_kx_n)=x^T\pmb{W}_{q}R^d_{\Theta,n-m}\pmb{W}_kx_n \tag{2.8} qmTkn=(RΘ,mdWqxm)T(RΘ,ndWkxn)=xTWqRΘ,n−mdWkxn(2.8)
其中, R Θ , n − m d = ( R Θ , m d ) T R Θ , n d \pmb{R}^d_{\Theta,n-m} = (\pmb{R}^d_{\Theta,m})^T\pmb{R}^d_{\Theta,n} RΘ,n−md=(RΘ,md)TRΘ,nd,注意 R Θ d \pmb{R}^d_{\Theta} RΘd是一个正交矩阵,这保证了位置信息在处理过程中的稳定性。此外,由于 R Θ d \pmb{R}^d_{\Theta} RΘd的稀疏性,式(2.8)的计算效率不高,作者在理论上提供了另一种实现。
2.3 模块3: Transformer Block
Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。
2.4 模块4: RMSNorm
RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
layer normalization 是根据下面的公式对
x
x
x的分布进行调整。
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
(2.9)
x = a * \frac{x - \overline{x}}{std + eps} + b \tag{2.9}
x=a∗std+epsx−x+b(2.9)
其中,
x
‾
\overline{x}
x是均值,
s
t
d
std
std是标准差,
e
p
s
eps
eps为一个很小的数,防止分母为零。
a
a
a、
b
b
b为参数,
b
b
b可以为零。
我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
x = a * \frac{x - \overline{x}}{std + eps} + b
x=a∗std+epsx−x+b进行了简化。RMSNorm的计算公式如下:
a
‾
i
=
a
i
R
M
S
(
a
)
g
i
,
w
h
e
r
e
R
M
S
(
a
)
=
1
n
Σ
i
=
1
n
a
i
2
(2.10)
\overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}} \tag{2.10}
ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2(2.10)
从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(ai−a))。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。
下面附上meta llama3中RMSNorm的源码,方便大家理解。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
2.5 模块5: Attention
llama3的attention模块主要做了4部分工作,分别是RoPE计算、分注意力分组机制实现、点积注意力计算 及 kv缓存策略实现。其中RoPE的计算在模块2中已经讲解,这里不在赘述。下文对GQA,点积注意力计算及KV缓存进行简单的讲解。
2.5.1 分组注意力机制(GQA)
llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
我们知道,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
为了清楚,这理举一个具体的例子:假设有一个token 序列
[
T
1
,
T
2
,
T
3
,
T
4
,
T
5
,
T
6
,
T
7
,
T
8
]
[T_1,T_2,T_3,T_4,T_5,T_6,T_7,T_8]
[T1,T2,T3,T4,T5,T6,T7,T8], 我们把这个token序列分成两个组来计算GQA。
-
step1: 分组
G r o u p 1 = [ T 1 , T 2 , T 3 , T 4 ] G r o u p 2 = [ T 5 , T 6 , T 7 , T 8 ] Group_1=[T1,T2,T3,T4] \\ Group_2=[T5,T6,T7,T8] Group1=[T1,T2,T3,T4]Group2=[T5,T6,T7,T8] -
step2:计算分组后的注意力
每个组内部计算注意力分数。为简单起见,我们假设我们有以下简化的注意力机制:
A t t e n t i o n S c o r e ( Q i , K i ) = Q i ⋅ K i d k Attention\ Score(Q_i,K_i) = \frac{Q_i \cdot K_i}{\sqrt{d_k}} Attention Score(Qi,Ki)=dkQi⋅Ki
其中 Q Q Q 是query, K K K 是 key, d k d_k dk 是键的维度。- 对于
G
r
o
u
p
1
Group_1
Group1:
- 计算标记 [ T 1 、 T 2 、 T 3 、 T 4 ] [T1、T2、T3、T4] [T1、T2、T3、T4] 的注意力得分。
- 这会生成一个 4×4 注意力矩阵。
- 对于
G
r
o
u
p
2
Group_2
Group2:
- 计算标记 [ T 5 、 T 6 、 T 7 、 T 8 ] [T5、T6、T7、T8] [T5、T6、T7、T8] 的注意力得分。
- 这会生成另一个 4×4 注意力矩阵。
- 对于
G
r
o
u
p
1
Group_1
Group1:
-
step3: 共享组内注意力分数
在每个组中,注意力分数是共享的。例如,第 1 组的注意力矩阵可能如下所示:
A t t e n t i o n S c o r e G r o u p 1 = [ a 11 a 12 a 13 a 14 a 21 a 22 a 23 a 24 a 31 a 32 a 33 a 34 a 41 a 42 a 43 a 44 ] Attention\ Score_{\ Group_1} = \left[ \begin{matrix} a_{11} & a_{12} & a_{13} &a_{14} \\ a_{21} & a_{22} & a_{23} &a_{24} \\ a_{31} & a_{32} & a_{33} &a_{34} \\ a_{41} & a_{42} & a_{43} &a_{44} \\ \end{matrix} \right] Attention Score Group1= a11a21a31a41a12a22a32a42a13a23a33a43a14a24a34a44
对于第二个分组:
A t t e n t i o n S c o r e G r o u p 2 = [ a 51 a 52 a 53 a 54 a 61 a 62 a 63 a 64 a 71 a 72 a 73 a 74 a 81 a 82 a 83 a 84 ] Attention\ Score_{\ Group_2} = \left[ \begin{matrix} a_{51} & a_{52} & a_{53} &a_{54} \\ a_{61} & a_{62} & a_{63} &a_{64} \\ a_{71} & a_{72} & a_{73} &a_{74} \\ a_{81} & a_{82} & a_{83} &a_{84} \\ \end{matrix} \right] Attention Score Group2= a51a61a71a81a52a62a72a82a53a63a73a83a54a64a74a84 -
step4:注意力计算
组中的每个标记根据计算出的分数关注其组中的其他标记(具体计算方法见2.5.2节)。例如, T 1 T_1 T1 将使用第 1 组注意力矩阵第一行的分数关注 T 2 T_2 T2、 T 3 T_3 T3 和 T 4 T_4 T4。 -
step5:合并结果
在计算每个组内的注意力后,我们将结果合并以形成最终的输出序列。每个标记的输出是它关注的标记值的加权和。 -
优点总结 :对查询进行分组有以下两个优点
- 降低复杂度:我们不再计算 8×8 矩阵的注意力,而是计算两个 4×4 矩阵,从而显著减少了计算量。
- 可扩展性:此方法更适合长序列,因为注意力计算随组大小而非整个序列长度二次增长。
2.5.2 注意力计算(Scaled Dot-Product Attention)
llama3 计算attention score时,使用了与《attention is all you need》一文中相同的计算方法,即点积注意力方法(Scaled Dot-Product Attention),由于Scaled Dot-Product Attention在《Transformer(二)–论文理解:transformer 结构详解》 一文中的2.2.1章节有详细的讲解,这里就不再展开。
2.5.3 KV缓存
llama3在计算 attention 时采用了kv cache策略。此策略的思想是缓存每个时间步的key和value的值,在推理阶段,由于模型是自回归模式生成文本,所以当我们对过往时间步有缓存结果时,会减少计算量,提高解码效率。
下面是llama3中Attention类的源码,大家可以参考理解
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
.
.
.
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
# 以下是Scaled Dot-Product Attention的计算
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
2.6 模块6: ADD
此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。
2.7 模块7: FFN
由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s
i
l
u
(
x
)
=
x
∗
σ
(
x
)
,
w
h
e
r
e
σ
(
x
)
i
s
t
h
e
l
o
g
i
s
t
i
c
s
i
g
m
o
i
d
.
silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid.
silu(x)=x∗σ(x), where σ(x) is the logistic sigmoid.
函数的激活曲线如下图:
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。
下面给出主要源码。
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
.
.
.
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
2.8 模块8: Linear
此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)