手撕BeamSearch代码,一线互联网架构师筑基必备技能之大数据开发篇

先自我介绍一下,小编浙江大学毕业,去过华为、字节跳动等大厂,目前阿里P7

深知大多数程序员,想要提升技能,往往是自己摸索成长,但自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!

因此收集整理了一份《2024年最新大数据全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友。
img
img
img
img
img

既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,涵盖了95%以上大数据知识点,真正体系化!

由于文件比较多,这里只是将部分目录截图出来,全套包含大厂面经、学习笔记、源码讲义、实战项目、大纲路线、讲解视频,并且后续会持续更新

如果你需要这些资料,可以添加V获取:vip204888 (备注大数据)
img

正文

        next_beam_indices = torch.zeros((batch_size,num_beams), dtype=next_indices.dtype)

        for batch_idx in range(batch_size):
            beam_idx=0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx=batch_idx*num_beams+next_index

                next_beam_scores[batch_idx, beam_idx] = next_score      #当前路径得分
                next_beam_tokens[batch_idx, beam_idx] = next_token      #当前时刻的token
                next_beam_indices[batch_idx, beam_idx] = batch_beam_idx  #先前对应的id

                beam_idx += 1

        return next_beam_scores.view(-1), next_beam_tokens.view(-1), next_beam_indices.view(-1)

    beam_scores, beam_next_tokens, beam_idx=process(input_ids,next_token_scores,next_tokens,next_indices)

    # 更新输入, 找到对应的beam_idx, 选择的tokens, 拼接为新的输入      #(batch*beam,seq_len)
    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
    cur_len = cur_len + 1
#输出
return input_ids,beam_scores

if name == ‘main’:
input_ids=torch.randint(0,100,size=(3,1))
print(input_ids)
input_ids,beam_scores=beam_search(input_ids,max_length=10,num_beams=3)
print(input_ids)


参考:transformers generate实现。


2. transformer generate() 解读



@torch.no_grad()
def generate( #模型入口
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional[“PreTrainedModel”] = None,
streamer: Optional[“BaseStreamer”] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:



10. go into different generation modes

根据不同的生产模型进行解码生产

if generation_mode == GenerationMode.ASSISTED_GENERATION:

#以beam search 为例子
elif generation_mode == GenerationMode.BEAM_SEARCH: #beam search 算法
# 11. prepare beam search scorer #参数初始化
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
#将输入进行扩展
# 12. interleave input_ids with num_beams additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run beam search 核心,beam search 算法解码
result = self.beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
sequential=generation_config.low_memory,
**model_kwargs,
)



def beam_search(
self, input_ids, encoder_output, attention_mask, num_beams, max_length, pad_token_id: int, eos_token_id: int
):
batch_size = self.beam_scorer.batch_size #扩展前batch size

num_beams = self.beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape     #扩展后batch

assert (
    num_beams * batch_size == batch_beam_size
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
next_tokens = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device)
next_indices = torch.zeros((batch_size, num_beams), dtype=torch.long, device=input_ids.device)

past: List[torch.Tensor] = []
while cur_len < max_length:
    #生成相应
    logits, past = self._decoder_forward(input_ids, encoder_output, attention_mask, past)    #迭代输出
    next_token_logits = logits[:, -1, :]    #当前时刻输出

    # adjust tokens for Bart, *e.g.*    cur_len=1 与 max_length 输出调整
    next_token_logits = self.adjust_logits_during_generation(
        next_token_logits, cur_len=cur_len, max_length=max_length
    )
    #归一化
    next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)    #归一化

    # pre-process distribution
    next_token_scores = self.logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)   #当前概率+先前概率

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
    #取前beam 个路径
    next_token_scores, next_tokens = torch.topk(
        next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
    )

    next_indices = next_tokens // vocab_size
    next_tokens = next_tokens % vocab_size
    #获取对应路径,路径得分,对应的id   核心,不同beam search 不同点
    beam_scores, beam_next_tokens, beam_idx = self.beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,

网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。

需要这份系统化的资料的朋友,可以添加V获取:vip204888 (备注大数据)
img

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

要这份系统化的资料的朋友,可以添加V获取:vip204888 (备注大数据)**
[外链图片转存中…(img-fKExrJfJ-1713405766825)]

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值