ColumnParallelLinear
和 RowParallelLinear
是两种并行线性层,它们的主要区别在于权重矩阵的分割方式和计算过程。具体来说:
ColumnParallelLinear
-
权重矩阵分割方式:
- 权重矩阵
W
被按列(column)分割成多个子矩阵,每个子矩阵在并行设备上进行计算。 - 如果权重矩阵
W
的形状是(output_dim, input_dim)
,它会被分割成P
个子矩阵,每个子矩阵的形状是(output_dim, input_dim / P)
。
- 权重矩阵
-
计算过程:
- 输入
x
在并行设备上是完整的,不需要分割。 - 每个设备计算局部的矩阵乘法:\text{local_output}_i = x \cdot W_i^T
- 所有设备上的局部输出结果
local_output_i
汇总起来形成最终输出: \text{output} = \sum_{i=1}^{P} \text{local_output}_i
- 输入
-
优势:
- 适合输入数据较大但输出维度较小的场景。
RowParallelLinear
-
权重矩阵分割方式:
- 权重矩阵
W
被按行(row)分割成多个子矩阵,每个子矩阵在并行设备上进行计算。 - 如果权重矩阵
W
的形状是(output_dim, input_dim)
,它会被分割成P
个子矩阵,每个子矩阵的形状是(output_dim / P, input_dim)
。
- 权重矩阵
-
计算过程:
- 输入
x
在并行设备上进行分割,每个设备上的输入为x_i
,维度为(batch_size, seq_len, input_dim / P)
。 - 每个设备计算局部的矩阵乘法: \text{local_output}_i = x_i \cdot W_i^T
- 所有设备上的局部输出
local_output_i
汇总起来形成最终输出: \text{output} = \text{concatenate}(\text{local_output}_1, \text{local_output}_2, \ldots, \text{local_output}_P, \text{axis}=-1)
- 输入
-
优势:
- 适合输出维度较大但输入数据较小的场景。