从零开始实现大语言模型(十二):文本生成策略

1. 前言

大语言模型GPTModel通过多轮推理生成连续自然语言文本,每轮推理仅生成一个token。对输入文本做tokenization,将输入文本转换成包含num_tokens个token ID的列表,并输入大语言模型GPTModel,可以得到num_tokens个维度为vocabulary_size的logits向量,第 i i i个logits向量是大语言模型根据前 i i i个token预测生成的下一个token的概率分数向量,logits向量中的第 k k k个概率分数值越大,表明大语言模型预测生成的下一个token的ID为 k k k的概率越高。使用softmax函数将最后一个logits向量归一化,使最后一个logits向量每个分量的值均介于0到1之间,所有分量之和等于1,可以得到大语言模型根据输入文本预测生成的下一个token的概率分布。

本文介绍大语言模型GPTModel预测生成连续自然语言文本的流程,以及4种从概率分布中选择下一个token的策略,并实现文本生成函数generate_text

2. 文本生成流程

大语言模型GPTModel通过多轮推理生成连续自然语言文本,每轮推理仅生成一个token。如下图所示,对输入文本Hello, I am做tokenization,将其转换成包含4个token ID的列表[15496, 11, 314, 716],并输入大语言模型GPTModel,预测生成ID为257的下一个token a。第2轮推理会将第1轮推理生成的token a添加到输入文本序列,得到包含5个token ID的列表[15496, 11, 314, 716, 257],并输入大语言模型GPTModel,预测生成ID为2746的下一个token model。依此类推,第6轮推理会将前5轮推理生成的token全部添加到输入文本序列,并将相应token ID列表输入大语言模型GPTModel,最终构造出文本序列Hello, I am a model ready to help.

图一

3. 文本生成策略

3.1 Greedy Decoding

上述文本生成流程中每轮推理会将包含num_tokens个token ID的列表输入大语言模型GPTModel。根据前文从零开始实现大语言模型(十一):构建大语言模型GPTModel可知,大语言模型GPTModel会输出num_tokens个维度为vocabulary_size的logits向量,第 i i i个logits向量是大语言模型根据前 i i i个token预测生成的下一个token的概率分数向量。logits向量中的第 k k k个概率分数值越大,表明大语言模型预测生成的下一个token的ID为 k k k的概率越高。使用softmax函数将最后一个logits向量归一化,使最后一个logits向量每个分量的值均介于0到1之间,所有分量之和等于1,可以得到大语言模型根据输入文本预测生成的下一个token的概率分布。

Greedy Decoding是一种最简单直接的从概率分布中选择下一个token的策略,其会从大语言模型每轮推理生成的下一个token的概率分布中选择最大概率值对应的index作为预测生成的下一个token的ID。如下图所示,对输入文本Hello, I am做tokenization,将相应token ID列表输入大语言模型GPTModel,并使用softmax函数将大语言模型输出的最后一个logits向量归一化,得到大语言模型根据输入文本Hello, I am预测生成的下一个token的概率分布。Greedy Decoding选择下一个token的概率分布中最大概率值对应的index257作为该轮推理预测生成的下一个token的ID。

图二

可以使用如下代码基于上述大语言模型文本生成策略Greedy Decoding实现大语言模型文本生成函数generate_text_greedy。首先使用tokenizer.encode方法对输入文本做tokenization,将输入文本text转换成包含num_tokens个token ID的列表。在每轮for循环中,使用大语言模型model推理输出num_tokens个维度为vocabulary_size的logits向量,并使用torch.softmax函数将最后一个logits向量归一化,得到下一个token的概率分布。最后使用torch.argmax函数从概率分布中选择最大概率值对应的index作为该轮推理预测生成的下一个token的ID。使用torch.cat方法将token ID列表与预测生成的下一个token的ID拼接起来,构造下一轮推理的输入。执行max_new_tokens轮推理,共生成max_new_tokens个token ID。最后使用tokenizer.decode方法将生成的token ID列表解码,得到大语言模型生成的自然语言文本:

import torch


def generate_text_greedy(
    model, start_context, max_new_tokens, context_size, tokenizer, stop_ids=None, compact_format=False
):
    model.eval()
    idx = tokenizer.encode(start_context, allowed_special=tokenizer.special_tokens_set)
    idx_tensor = torch.tensor(idx).unsqueeze(0)

    for _ in range(max_new_tokens):
        idx_cond = idx_tensor[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]
        probas = torch.softmax(logits, dim=-1)
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RuizhiHe

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值