import torch
import torch.nn as nn
import torch.optim as optim
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# 1. Load the pretrained model and tokenizer
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)    # load the pretrained GPT-2 language model
tokenizer = GPT2Tokenizer.from_pretrained(model_name)  # load the matching GPT-2 tokenizer
tokenizer.pad_token = tokenizer.eos_token              # GPT-2 has no pad token, so reuse EOS for padding
# 2. Define the LoRA adapter layer
class LoRALayer(nn.Module):
    def __init__(self, hidden_size, r=4):
        super(LoRALayer, self).__init__()
        self.r = r
        # Initialize the two low-rank factors W_a and W_b, scaled down by 0.01
        # (the original LoRA paper initializes one factor to zero so training starts from the identity)
        self.W_a = nn.Parameter(torch.randn(hidden_size, r) * 0.01)
        self.W_b = nn.Parameter(torch.randn(r, hidden_size) * 0.01)

    def forward(self, x):
        # LoRA output: x + (x @ W_a) @ W_b
        return x + (x @ self.W_a) @ self.W_b
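# Quick sanity check (illustrative addition, not from the original post): the LoRA layer must
# preserve the hidden-state shape so it can be dropped in after the transformer.
_dummy = torch.randn(2, 10, 768)                     # (batch, seq_len, hidden_size) for GPT-2 small
assert LoRALayer(768)(_dummy).shape == _dummy.shape  # output shape is unchanged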
# 3. Insert a LoRA layer into the GPT-2 model (for simplicity, a single adapter on the
#    final hidden states rather than one inside each attention module)
class LoraGPT2Model(GPT2LMHeadModel):
    def __init__(self, config):
        super(LoraGPT2Model, self).__init__(config)
        # Create the LoRA layer from the model configuration
        self.lora_layer = LoRALayer(config.hidden_size)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Run the original GPT-2 transformer
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        hidden_states = outputs[0]  # final hidden states
        # Pass the hidden states through the LoRA layer
        lora_output = self.lora_layer(hidden_states)
        # Project to vocabulary logits with the language-model head
        lm_logits = self.lm_head(lora_output)
        # Compute the loss if labels are provided, shifting so that position i predicts token i+1
        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        # Return the loss and the logits
        return (loss, lm_logits)

# Instantiate a GPT-2 model with the LoRA layer; the pretrained weights are loaded and
# the new LoRA parameters are randomly initialized
lora_model = LoraGPT2Model.from_pretrained(model_name)
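# Illustrative addition (not part of the original snippet): freeze the pretrained GPT-2
# weights so that only the two small LoRA matrices remain trainable, which is the point of
# LoRA, and report how few parameters that leaves to fine-tune.
for _name, _param in lora_model.named_parameters():
    _param.requires_grad = "lora_layer" in _name
_trainable = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
_total = sum(p.numel() for p in lora_model.parameters())
print(f"Trainable (LoRA) parameters: {_trainable} of {_total} total")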
# 4. Prepare training data (simplified example)
train_texts = ["Hello, how are you?", "I am fine, thank you!"]  # toy training texts
# Encode the texts into model inputs (padding uses the pad token set above)
train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors="pt")
input_ids = train_encodings.input_ids
attention_mask = train_encodings.attention_mask
labels = input_ids.clone()  # for causal LM training, the labels are the inputs themselves
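# Optional refinement (an addition beyond the original snippet): mask padding positions so they
# do not contribute to the loss; nn.CrossEntropyLoss ignores targets of -100 by default. Because
# the pad token is the EOS token here, any genuine EOS labels would also be masked.
labels[labels == tokenizer.pad_token_id] = -100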
# 5. Configure the optimizer (the loss function is applied inside the model's forward pass)
optimizer = optim.Adam(lora_model.parameters(), lr=5e-5)  # Adam, learning rate 5e-5; frozen weights receive no gradients and stay untouched
# 6. Training loop
num_epochs = 1      # number of training epochs
lora_model.train()  # put the model into training mode
for epoch in range(num_epochs):
    optimizer.zero_grad()  # clear any accumulated gradients
    loss, logits = lora_model(input_ids, attention_mask=attention_mask, labels=labels)  # forward pass and loss
    loss.backward()        # backpropagate
    optimizer.step()       # update the parameters
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")  # report the loss for this epoch
# 7. After training, save the model and tokenizer
lora_model.save_pretrained('./lora_gpt2')  # save the fine-tuned model (LoRA weights included)
tokenizer.save_pretrained('./lora_gpt2')   # save the tokenizer
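# Illustrative addition: the saved directory can be loaded back with the same classes;
# save_pretrained() stores the adapted GPT-2 weights together with the LoRA matrices.
reloaded_model = LoraGPT2Model.from_pretrained('./lora_gpt2')
reloaded_tokenizer = GPT2Tokenizer.from_pretrained('./lora_gpt2')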