源码地址:https://gitee.com/guojialiang2023/gpt2
模型
文本生成
配置
class GenerateConfig(object):
def __init__(self,
seq_len: int,
nucleus_prob: float,
use_gpu: bool):
self.seq_len = seq_len
self.nucleus_prob = nucleus_prob
self.use_gpu = use_gpu
生成框架
import torch
import torch.nn as nn
from typing import List
class GenerationSpec(object):
def initialize(self):
pass
def construct_model(self) -> nn.Module:
raise NotImplementedError()
def encode_context(self, context: str) -> List[int]:
raise NotImplementedError()
def decode_tokens(self, tokens: List[int]) -> str:
raise NotImplementedError()
def decorate_sequence(self, sequence: torch.Tensor, offset: int
) -> torch.Tensor:
return sequence
文本生成类实现
import torch
from model import Past
from generation import GenerationSpec, GenerateConfig
from typing import List, Optional, Tuple
class Generator(object):
def __init__(self, spec: GenerationSpec, config: GenerateConfig):
self.spec = spec
self.config = config
def initialize(self, from_model: Optional[str] = None):
# Initialize generation environment and construct a model.
self.spec.initialize()
self.model = self.spec.construct_model().eval()
# Load trained model parameters.
if from_model:
ckpt = torch.load(from_model, map_location='cpu')
self.model.load_state_dict(ckpt['model'])
# Move the model to GPU device and convert the data type to half
# precision.
if self.config.use_gpu:
self.model.cuda().half()
def generate(self, context: str) -> str:
words = self.spec.encode_context(context)
current, past = words, None
while len(words) < self.config.seq_len:
# Predict the next word token from the given context.
probs, past = self._predict_probs(current, past)
next_word = self._sample_from_top_p(probs)
# Change the context to the predicted word.
words.append(next_word)
current = [next_word]
return self.spec.decode_tokens(words)
@torch.no_grad()
def _predict_probs(self,
words: List[int],
past: Optional[List[Past]] = None
) -> Tuple[torch.Tensor, List[Past]]:
x = torch.tensor(words, dtype=torch.long)
x = self.spec.decorate_sequence(
x, offset=past[0][0].size(-2) if past is not None else 0)
if self.config.use_gpu:
logits, past = self.model(x.cuda(), past)
logits = logits.cpu().float()
else:
logits, past = self.model(x, past)
return logits[-1, :].softmax(-1), past
def _sample_from_top_p(self, probs: torch.Tensor) -> int:
probs, indices = probs.sort(descending=True)
mask = probs.cumsum(-1) > self.config.nucleus_prob
mask[0] = False
probs.masked_fill_(mask, 0)
# Sample from filtered distribution.
return indices[probs.multinomial(1)[0]].item()
代码定义了一个用于文本生成的Generator
类,它使用了GPT-2模型。这个类能够基于给定的上下文生成文本。下面是对代码中关键部分的详细解释:
__init__
方法
__init__
是类的构造函数,它接受两个参数:spec
和config
。spec
(GenerationSpec
类型)包含了模型构造和上下文编码/解码的规范,这是一个抽象定义,用于处理特定于模型的逻辑。config
(GenerateConfig
类型)包含了生成过程的配置,如是否使用GPU、生成序列的长度等。
initialize
方法
initialize
方法用于初始化生成环境,构造模型,并可选地从已训练的模型中加载参数。- 如果
from_model
参数被提供,方法会加载这个模型的参数。 - 如果配置为使用GPU,模型会被移动到GPU上,并转换为半精度浮点数以提高性能。
generate
方法
generate
方法接受一个字符串类型的context
作为参数,这个上下文用作文本生成的起点。- 方法首先将上下文编码为模型能理解的形式(通常是一系列token的ID)。
- 然后,它在给定的上下文基础上循环生成文本,直到达到配置的序列长度
seq_len
。 - 在每一步中,它都会调用
_predict_probs
方法来预测下一个单词的概率分布,并通过_sample_from_top_p
方法从这个分布中采样一个单词。 - 生成的单词被添加到上下文中,用作下一次预测的输入。
_predict_probs
方法
- 这是一个私有方法,用于基于当前的单词(或单词序列)和过去的状态(如果有的话)来预测下一个单词的概率分布。
- 它接受当前的单词序列和可选的过去状态作为输入,并返回下一个单词的概率分布和更新后的状态。
- 如果配置为使用GPU,输入和输出会相应地移动到GPU或CPU上,并且输出的logits会被转换为浮点数。
_sample_from_top_p
方法
- 这个私有方法用于实现“nucleus sampling”(也称为“top-p sampling”),这是一种从概率分布中采样单词的方法,它仅考虑累积概率超过某个阈值(
nucleus_prob
)的最高概率单词。 - 通过这种方式,它有效地过滤掉了低概率的单词,减少了生成随机无关文本的可能性,同时保留了一定程度的随机性以增加文本的多样性。
文本生成代码
这段代码是一个完整的Python脚本,用于通过命令行界面生成使用GPT-2模型训练的文本。它主要包括定义GPT2GenerationSpec
类、generate_sentence_with_gpt2_model
函数和add_subparser
函数,以及如何使用argparse
库来解析命令行参数。下面是对这些关键部分的详细解释:
GPT2GenerationSpec
类
GPT2GenerationSpec
类继承自GenerationSpec
,专门用于GPT-2模型的文本生成。它通过初始化参数配置词汇表、序列长度、Transformer模型的层次、注意力头数、维度和维度增加率。initialize
方法加载词汇表并初始化分词器。construct_model
方法构建Transformer模型实例,这里的模型配置是根据初始化时传入的参数定制的。encode_context
方法将文本上下文编码为模型能够处理的token序列。decode_tokens
方法将token序列解码回文本字符串,如果遇到结束符eos_idx
,则只解码到该符号为止。
generate_sentence_with_gpt2_model
函数
- 这个函数是脚本的主要执行点,用于生成文本。它首先根据命令行参数创建
GPT2GenerationSpec
和GenerateConfig
实例。 - 接着,它初始化
Generator
实例,并可选地从文件中加载预训练的模型参数。 - 使用
while True
循环不断读取用户输入,并生成相应的文本输出。
add_subparser
函数
- 这个函数用于
argparse
库,以定义命令行参数和子命令。它允许用户通过命令行指定词汇表文件路径、模型文件路径、模型配置(如序列长度、层数、注意力头数等)和生成选项(如nucleus采样概率和是否使用GPU)。 - 它使得脚本的使用更加灵活,用户可以根据需要调整生成文本的配置。
至此全部从零开始实现GPT2已全部完成