标题RL强化学习从小白到老鸟(二)——手撕GPT(零基础保姆级教学)
简介
帮老师带几个研究生,他们想学GPT和强化学习,正好我两者都略懂,正在研究结合两者优势创建一个全新的结构,就先手撕GPT做个教学贴吧。
GPT简介
学过GPT的可以直接跳过本节。
GPT是基于 Transformer 架构的预训练模型,它通过大规模语料库的训练,能够捕捉到复杂的语言特征和模式。这使得GPT在文本生成任务中能够产生连贯且具有一定逻辑的文本。
GPT 可以适用于多种文本生成任务,如故事生成、自动文章写作、对话生成等。其预训练的特性使得它在微调后能够迅速适应新的文本生成任务,减少了任务特定模型训练的需求。
GPT 是一个自回归模型,它在生成文本时会考虑到之前所有生成的内容,这保证了文本的连贯性和逻辑性。它每次生成一个词或一个字符,然后将其作为下一步生成的上下文的一部分,这种方式很适合逐步构建整体一致的长文本。
随着模型规模的增大,GPT表现出了显著的性能提升。例如,从GPT-2到GPT-3,模型的层数和参数量大幅增加,使其在各种语言任务上的表现都有显著提升。这证明了GPT架构对于规模扩展的高适应性。
由于GPT的零样本(zero-shot)和少样本(few-shot)学习能力,它能够在没有大量特定任务数据的情况下调整并执行多种语言任务。这是通过其在大规模数据集上的预训练实现的,使其学会了广泛的语言模式和知识。
架构(最简GPT结构)
-
多头自注意力(Multi-Head Self-Attention) :
- 在这个模块中,模型可以同时处理不同子空间的信息。这允许模型在生成每个词时考虑到句子中的多个上下文关系。
-
逐位置前馈网络(Position-wise Feed-Forward Networks) :
- 这是Transformer每一层中的一个组成部分,它对每一个位置的表示进行独立的变换,增加了模型的表达能力。
-
正弦位置编码(Sinusoidal Positional Encoding) :
- Transformer模型没有像循环神经网络那样的递归结构,因此需要位置编码来提供序列中各元素的位置信息。GPT使用的是正弦和余弦函数的组合来编码位置。
-
填充掩码(Padding Mask) :
- 在处理不等长的句子时,较短的句子需要用特殊的填充符号(如
<pad>
)填充,填充掩码用于在自注意力层忽略这些填充符号的影响。
- 在处理不等长的句子时,较短的句子需要用特殊的填充符号(如
-
后续掩码(Look-ahead Mask 或 Causal Mask) :
- 用于确保在生成当前词时只使用之前的词,这对于任何自回归模型(如GPT)都是必要的,以防止信息的未来泄露。
-
解码器层(Decoder Layer) :
- GPT模型实际上只使用了Transformer的解码器架构的修改版(没有编码器-解码器注意力),每个解码器层包括多头自注意力和前馈网络。
-
解码器(Decoder) :
- 解码器是由多个解码器层堆叠而成,整个结构被训练来预测下一个可能的词。
如何手搓一个麻雀级别的GPT
通过前面的介绍,大家应该知道了GPT的基本结构,接下来就代码实操一下
导入模块
import torch
import torch.nn as nn
import torch.optim as optim```
#这段代码很简单,就是将pytorch的网络模块和优化器模块导入,方便后续训练使用
定义模型
class SimpleGPT(nn.Module): # SimpleGPT通过继承nn.Module模块(所有神经网络的基类)。
def __init__(self, vocab_size, d_model, num_heads, num_layers, dropout=0.1):
super(SimpleGPT, self).__init__() # init初始化方法用于初始化模型的参数,如词汇表大小vocab_size,模型维度d_model,d多头注意力机制的头数num_heads,Transformer的层数num_layer以及dropout比率。
self.embedding = nn.Embedding(vocab_size, d_model) # 创建嵌入层,将词汇表中的每个词标记(通常是整数索引)映射为一个固定大小的向量。d_model是嵌入向量的维度,这个嵌入向量在训练过程会被优化。
self.pos_encoder = nn.Parameter(torch.zeros(1, 1024, d_model)) # 创建位置编码,是Transformer模型的一个重要部分,它给每个输入的词元添加一个唯一的向量,这有助于模型理解词源在序列中的位置,
#这里我们定义了一个参数,它的大小是(1,1024, d_model),其中1024是假设的最大序列长度(num_token),d_model是每个位置编码向量的维度。
# Encoder Layers
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dropout=dropout, bitch_first=True) # 定义类一个Transformer编码器层,然后使用这个层创建一个多层的Transformer编码器,d_model是特征维度,
self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers) # num_heads是多头注意力机制的头数,dropout是在训练中随机丢弃的节点比率,防止过拟合,num_layers是层数
self.out = nn.Linear(d_model, vocab_size) # 输出是一个全连接层,将编码器的输出转回词汇表大小的空间,用于预测每个位置可能得下一个词标记。
def forward(self, src, src_mask):
"""定义了模型如何从输入到输出的数据流 """
src = self.embedding(src) + self.pos_encoder[:, :src.size(1)] # 输入数据src通过嵌入层并与位置编码相加
output = self.transformer_encoder(src, src_mask) # 然后传递给编码器层,src_mask用来在多头注意力汇总遮盖不应该被模型看到的未来未知的信息。
output = self.out(output) # 通过输出层生成预测结果
return output
def generate_square_subsequent_mask(self, sz):
"""生成遮罩,生成一个三角形矩阵,用于在训练的过程中遮盖序列的未来未知,保证模型在做预测时只能利用之前的信息。这是实现自回归模型(如GPT)的关键部分"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) # 生成上三角矩阵,然后转置。
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) # 将布尔矩阵转换为浮点数并填充特定值
return mask
逐行注解,看不懂的话建议换个赛道了, 哈哈哈(开个玩笑)
初始化模型参数
# 设置参数
vocab_size = 100 # 小词库的大小
d_model = 512 # 嵌入维度
num_heads = 8 # 注意力头数
num_layers = 2 # Transformer层数
dropout = 0.1 # Dropout比率
model = SimpleGPT(vocab_size, d_model, num_heads, num_layers, dropout)
# 实例化模型类
定义优化器和损失函数
criterion = nn.CrossEntropyLoss() # 创建交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 创建Adama优化器
训练模型
# 这段代码展示了一个简单的训练循环,用于训练之前定义的SimleGPT模型。
model.train() # 将模型设置为训练模式
for epoch in range(10): # 举例训练10个epoch,每个epoch中,模型会看到所有的训练数据一次
src = torch.randint(0, vocab_size, (10, 32)) # 随机生成一些数据,数据生成是为了用于训练,这里应该替换为真实数据训练。
src_mask = model.generate_square_subsequent_mask(src.size(0))
output = model(src, src_mask)
loss = criterion(output.view(-1, vocab_size), src.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch} Loss {loss.item()}") # 打印损失值。
# 这个训练循环是机器学习训练过程的典型表示,涵盖了数据准备、模型调用、损失计算、反向传播和参数更新等关键步骤。
这个是一个测试项目,用一些随机生成的数据,验证模型的运行情况。
想要文本类数据项目的源码,关注点赞私聊,发邮箱。 前面我都是放在github,结果大家只拿代码,不点赞,生气气了╭(╯^╰)╮
RL学习的项目源码,均会放在github仓库,欢迎大家关注,本文的源码也会在后面空了(看到star了)更新到github
建议大家先看我的上一篇文章,有助于理解本文(想要你的star),star越多,更新越多,都是干货,真实项目,面试宝典