class VocabParallelEmbedding(torch.nn.Module):
    """Embedding layer whose weight is sharded along the vocabulary axis.

    Behaves like ``torch.nn.Embedding`` (same defaults), except that each
    model-parallel rank only stores the rows of the embedding table for its
    own slice of the vocabulary; the full result is recovered via a reduce
    across ranks in ``forward``.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
    ) -> None:
        super().__init__()
        # Record the embedding hyper-parameters exactly as given.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self._weight = None
        # Slice the vocabulary across model-parallel ranks; the range is
        # half-open: [vocab_start_index, vocab_end_index).
        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_model_parallel_rank(), get_model_parallel_world_size()
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
        # This rank only materializes its own shard of the embedding table.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim))
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Boolean mask of tokens whose embedding rows live on some other rank.
        out_of_range = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Translate global token ids into local row indices; tokens owned by
        # other ranks are redirected to row 0 (their output is zeroed below).
        local_indices = input_.clone() - self.vocab_start_index
        local_indices[out_of_range] = 0
        # Look up embeddings from this rank's shard.
        # NOTE(review): padding_idx here is the *global* index applied to the
        # local weight shard — presumably callers only use padding_idx with a
        # single partition; verify against callers.
        output_parallel = F.embedding(
            local_indices,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Zero the rows this rank does not own so the reduce below combines
        # exactly one real embedding per token across ranks.
        output_parallel[out_of_range, :] = 0.0
        # Combine partial results from every model-parallel rank.
        return reduce_from_model_parallel_region(output_parallel)
VocabParallelEmbedding 类详解

VocabParallelEmbedding 类是一个嵌入层,专门用于在词汇表维度上进行并行化。它基于 torch.nn.Embedding 进行了适配,保留了其默认参数值,同时增加了模型并行化的功能。
类的构造函数 __init__
python
Copy code
def __init__( self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, ) -> None: super(VocabParallelEmbedding, self).__init__()
- num_embeddings:词汇表大小。
- embedding_dim:嵌入维度(隐藏状态大小)。
- 其他参数与 torch.nn.Embedding 的参数类似,用于控制嵌入行为。
python
Copy code
self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq self.sparse = sparse self._weight = None
- 初始化各类参数并保存。
python
Copy code
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( self.num_embeddings, get_model_parallel_rank(), get_model_parallel_world_size() ) self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
- 通过 VocabUtility.vocab_range_from_global_vocab_size 计算当前 GPU 负责的词汇范围(含头不含尾)。
- self.num_embeddings_per_partition 保存每个分区的词汇表大小。
python
Copy code
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) _initialize_affine_weight( self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method )
- 分配并初始化权重。每个 GPU 只分配其负责范围内的权重。
前向传播函数 forward
python
Copy code
def forward(self, input_: torch.Tensor) -> torch.Tensor: input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0
- input_mask 用于标记输入中不在当前 GPU 负责范围内的词汇。
- masked_input 先将全局词汇索引减去 self.vocab_start_index 转换为局部索引,再将不在负责范围内的位置设置为 0。
python
Copy code
output_parallel = F.embedding( masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse, )
- 使用 F.embedding 从本分区的权重中获取嵌入表示。
python
Copy code
output_parallel[input_mask, :] = 0.0
- 将不在负责范围内的词汇的嵌入表示设置为 0。
python
Copy code
output = reduce_from_model_parallel_region(output_parallel) return output
- 调用 reduce_from_model_parallel_region 将各 GPU 上的并行嵌入表示合并,得到最终的输出。
总结
VocabParallelEmbedding
类实现了在词汇表维度上的并行化嵌入层,通过将词汇表划分给不同的 GPU 并行处理来提升效率。在前向传播过程中,通过掩码和索引调整确保每个 GPU 只处理其负责范围内的词汇,最终通过 reduce 操作合并所有 GPU 的结果。