目录
一、核心目标
通过构建一个融合 Transformer 和 MoE 结构的模型,实现 “自然语言描述→Python 代码” 的生成。模型会根据输入的文本描述,自动选择最适合的 “专家网络” 生成代码,并通过语法校验确保输出代码的有效性。
二、模块详解
1. 环境配置与参数设置
- 库导入:导入 PyTorch(深度学习框架)、tokenizers(分词工具)、AST(语法分析)等核心库,用于模型构建、数据处理和语法校验。
- 设备配置:自动检测并使用 GPU(若可用)或 CPU 进行计算,提升训练效率。
- 超参数定义:设置模型结构和训练相关的关键参数,例如:
- 序列长度(
MAX_TEXT_LEN/MAX_CODE_LEN):限制输入文本和生成代码的最大长度。 - 模型维度(
EMBEDDING_DIM/FFN_DIM):控制嵌入层和前馈网络的维度,影响模型表达能力。 - 专家网络配置(
NUM_EXPERTS/TOP_K):设置专家数量和每次激活的专家数(MoE 核心参数)。 - 训练参数(
BATCH_SIZE/EPOCHS):控制训练批次大小和迭代轮数。
- 序列长度(
2. 数据处理模块
负责生成、加载和预处理训练数据(自然语言 - 代码对),为模型提供输入。
-
CodeDataset类:自定义数据集,实现对 “自然语言描述 - 代码” 对的处理:- 编码:使用分词器将文本和代码转换为整数序列(token ID)。
- 截断与填充:确保序列长度统一(符合
MAX_TEXT_LEN/MAX_CODE_LEN),短序列补零(padding),长序列截断。 - 特殊标记:添加
<eos>(结束标记)和<pad>(填充标记),帮助模型识别序列边界。
-
generate_sample_data函数:生成示例训练数据(因无法下载外部数据集):- 基于 15 个基础代码样本(如 “计算阶乘”“检查回文”)扩展出 500 个多样化样本。
- 通过随机修改函数名(如
add→calculate)或变量名(如a→x)增加数据多样性,模拟真实代码场景。
-
数据缓存:生成的数据保存为 JSON 文件,避免重复生成,提升二次运行效率。
3. 分词器训练
将自然语言和代码转换为模型可理解的 token 序列,是文本输入模型的前提。
train_tokenizers_from_data函数:使用 BPE(字节对编码)算法训练两个分词器:- 文本分词器:处理自然语言描述(如 “计算两个数的和”),生成适合文本的词汇表(大小 3000)。
- 代码分词器:处理 Python 代码(如
def add(a, b): return a + b),生成适合代码的词汇表(大小 5000)。
- 分词器保存与加载:训练好的分词器保存到本地,下次运行可直接加载,节省训练时间。
4. 语法校验功能
确保生成的代码符合 Python 语法规则,提升输出质量。
is_valid_python_code函数:使用 Python 内置的ast模块解析代码,检查语法是否正确:- 若解析成功,返回 “语法正确”;若失败(如缩进错误、括号不匹配),返回具体错误信息。
improve_code_syntax函数:简单修复常见语法问题(主要是缩进错误):- 根据代码结构(如
:结尾的行需要缩进)自动调整缩进层级,提升代码语法正确性。
- 根据代码结构(如
5. 模型结构(MoE+Transformer)
核心模块,实现 “自然语言→代码” 的映射,由文本编码器、专家网络、门控网络三部分组成。
-
PositionalEncoding类:为序列添加位置信息(Transformer 必备组件):- 通过正弦 / 余弦函数生成位置编码,使模型感知 token 在序列中的位置(如 “a+b” 中 “a” 在 “+” 之前)。
-
TextEncoder类:自然语言编码器:- 将文本 token 序列转换为嵌入向量,结合位置编码后,通过均值池化得到 “句子级特征向量”,用于后续门控网络选择专家。
-
Expert类:专家网络(每个专家都是一个 Transformer 解码器):- 每个专家负责处理特定类型的代码生成任务(如数学运算、字符串处理)。
- 输入代码序列,通过 Transformer 解码器输出下一个 token 的预测分布(用于生成代码)。
-
Gating类:门控网络(MoE 的核心):- 接收文本编码器的特征向量,计算每个专家的权重(概率)。
- 仅激活权重最高的
TOP_K个专家(稀疏激活),减少计算量的同时提升模型专注度。
-
MoECodeGenerator类:整合上述组件的完整模型:- 输入文本和代码序列,通过文本编码器生成特征→门控网络选择专家→专家输出预测→加权融合专家结果,最终输出代码 token 的预测分布。
6. 损失函数与训练逻辑
负责模型参数优化,确保模型能学习到 “自然语言→代码” 的映射规律。
-
损失函数:
- 交叉熵损失(
ce_loss):衡量模型预测的代码 token 与真实 token 的差异,确保生成准确性。 - 负载均衡损失(
lb_loss):通过 MSE 损失使各专家的平均权重接近均匀分布,避免部分专家被过度使用而其他专家闲置。 - 总损失:
ce_loss + 0.1 * lb_loss(兼顾准确性和专家均衡性)。
- 交叉熵损失(
-
train函数:模型训练过程:- 批量输入文本和代码,前向传播计算预测结果和损失。
- 反向传播更新模型参数,最小化损失。
-
evaluate函数:模型验证过程:- 在验证集上计算损失(不更新参数),监控模型是否过拟合(如验证损失上升但训练损失下降)。
7. 代码生成函数
基于训练好的模型,根据自然语言描述生成代码,并确保语法正确。
generate_valid_code函数:- 编码输入文本:将自然语言描述转换为模型可处理的 token 序列。
- 自回归生成:从
<eos>开始,逐步预测下一个 token,直到生成<eos>或达到最大长度。 - 采样策略:使用核采样(top-p)和温度调整控制生成多样性(温度越低,生成越保守;top-p 越小,过滤低概率 token 越严格)。
- 多轮尝试与校验:最多尝试 3 次生成,若某次生成的代码通过 AST 语法校验,则返回该代码;否则返回最后一次生成的结果。
8. 主程序流程
串联上述所有模块,实现从数据准备到模型训练再到代码生成的完整流程:
- 数据准备:加载缓存数据或生成新数据(500 条 “文本 - 代码” 对)。
- 分词器准备:加载已训练的分词器或基于当前数据训练新分词器。
- 数据集划分:将数据按 9:1 划分为训练集(用于模型学习)和验证集(用于监控性能)。
- 模型初始化:创建 MoE 模型并设置关键参数(如词汇表大小、专家数量)。
- 模型训练:执行 15 轮训练,每 5 轮保存一次模型,绘制训练 / 验证损失曲线(直观展示模型收敛情况)。
- 代码生成示例:使用训练好的模型对中英文查询(如 “计算阶乘”“Check if a number is even”)生成代码,并输出结果及语法校验状态。
三、基于混合专家模型的代码生成的Python代码完整实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import ast
import json
import zipfile
from pathlib import Path
from sklearn.model_selection import train_test_split
# 配置中文字体
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# ----------------------------
# 超参数设置
# ----------------------------
MAX_TEXT_LEN = 128
MAX_CODE_LEN = 256
EMBEDDING_DIM = 256 # 减小维度以便在CPU上运行
FFN_DIM = 512
NUM_HEADS = 4
NUM_LAYERS = 2
NUM_EXPERTS = 4
TOP_K = 2
BATCH_SIZE = 8 # 减小批次大小
EPOCHS = 15 # 减少训练轮数
LEARNING_RATE = 3e-4
# ----------------------------
# 数据处理与分词器
# ----------------------------
class CodeDataset(Dataset):
"""自然语言-代码对数据集"""
def __init__(self, data, text_tokenizer, code_tokenizer):
self.data = data
self.text_tokenizer = text_tokenizer
self.code_tokenizer = code_tokenizer
self.pad_id = code_tokenizer.token_to_id('<pad>')
self.eos_id = code_tokenizer.token_to_id('<eos>')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
text, code = self.data[idx]
# 自然语言编码
text_enc = self.text_tokenizer.encode(text).ids
text_enc = text_enc[:MAX_TEXT_LEN - 1] + [self.eos_id]
text_enc += [self.pad_id] * (MAX_TEXT_LEN - len(text_enc))
# 代码编码
code_enc = self.code_tokenizer.encode(code).ids
code_enc = [self.eos_id] + code_enc[:MAX_CODE_LEN - 2] + [self.eos_id]
code_enc += [self.pad_id] * (MAX_CODE_LEN - len(code_enc))
return (
torch.tensor(text_enc, dtype=torch.long),
torch.tensor(code_enc, dtype=torch.long)
)
def generate_sample_data(num_samples=500):
"""生成示例数据"""
base_samples = [
"计算两个数的和\tdef add(a, b):\n return a + b",
"计算斐波那契数列第n项\tdef fibonacci(n):\n if n <= 0:\n return 0\n elif n == 1:\n return 1\n else:\n return fibonacci(n-1) + fibonacci(n-2)",
"判断一个数是否为质数\tdef is_prime(n):\n if n <= 1:\n return False\n for i in range(2, int(n**0.5)+1):\n if n % i == 0:\n return False\n return True",
"将列表中的元素去重\tdef unique_list(lst):\n return list(set(lst))",
"计算列表的平均值\tdef average(lst):\n return sum(lst) / len(lst) if lst else 0",
"将摄氏度转换为华氏度\tdef celsius_to_fahrenheit(c):\n return c * 9/5 + 32",
"统计字符串中单词的数量\tdef word_count(s):\n return len(s.split())",
"求两个数的最大公约数\tdef gcd(a, b):\n while b:\n a, b = b, a % b\n return a",
"生成指定范围的随机整数\timport random\ndef random_int(min_val, max_val):\n return random.randint(min_val, max_val)",
"检查字符串是否为回文\tdef is_palindrome(s):\n s = s.lower().replace(' ', '')\n return s == s[::-1]",
"将列表元素从小到大排序\tdef sort_list(lst):\n return sorted(lst)",
"计算阶乘\tdef factorial(n):\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)",
"求列表中所有元素的和\tdef sum_list(lst):\n return sum(lst)",
"将字典的值转换为列表\tdef dict_values_to_list(d):\n return list(d.values())",
"检查列表是否为空\tdef is_empty(lst):\n return len(lst) == 0"
]
# 样本
samples = base_samples.copy()
while len(samples) < num_samples:
# 随机选择一个基础样本并稍作修改
import random
base = random.choice(base_samples)
text, code = base.split('\t', 1)
# 随机修改函数名或变量名
func_names = ["calculate", "compute", "get", "find", "determine"]
var_names = ["x", "y", "val", "num", "a", "b", "i", "j"]
if random.random() < 0.3:
# 修改函数名
old_func = code.split('def ')[1].split('(')[0]
new_func = random.choice(func_names)
code = code.replace(f'def {old_func}', f'def {new_func}')
if random.random() < 0.2:
# 修改变量名
old_var = random.choice(var_names)
new_var = random.choice([v for v in var_names if v != old_var])
code = code.replace(old_var, new_var)
samples.append(f"{text}\t{code}")
return [s.split('\t', 1) for s in samples]
def train_tokenizers_from_data(text_data, code_data, save_dir):
"""从数据训练BPE分词器"""
Path(save_dir).mkdir(exist_ok=True)
# 1. 文本分词器(自然语言)
text_tokenizer = Tokenizer(BPE(unk_token='<unk>'))
text_tokenizer.pre_tokenizer = Whitespace()
text_trainer = BpeTrainer(
special_tokens=['<pad>', '<unk>', '<eos>'],
vocab_size=3000
)
text_tokenizer.train_from_iterator(text_data, text_trainer)
text_tokenizer.save(os.path.join(save_dir, 'text_tokenizer.json'))
# 2. 代码分词器
code_tokenizer = Tokenizer(BPE(unk_token='<unk>'))
code_tokenizer.pre_tokenizer = Whitespace()
code_trainer = BpeTrainer(
special_tokens=['<pad>', '<unk>', '<eos>'],
vocab_size=5000
)
code_tokenizer.train_from_iterator(code_data, code_trainer)
code_tokenizer.save(os.path.join(save_dir, 'code_tokenizer.json'))
return text_tokenizer, code_tokenizer
# ----------------------------
# AST语法校验功能
# ----------------------------
def is_valid_python_code(code):
"""使用AST模块检查Python代码的语法有效性"""
try:
ast.parse(code)
return True, "语法正确"
except SyntaxError as e:
return False, f"语法错误: {str(e)}"
except Exception as e:
return False, f"解析错误: {str(e)}"
def improve_code_syntax(code):
"""简单修复常见的代码语法问题"""
lines = code.split('\n')
fixed_lines = []
indent_level = 0
indent_step = 4 # 使用4个空格缩进
for line in lines:
stripped = line.strip()
# 减少缩进的情况
if stripped.endswith(':') and not stripped.startswith(('elif', 'else')):
fixed_lines.append(' ' * indent_level + stripped)
indent_level += indent_step
elif stripped.startswith(('return', 'break', 'continue')):
fixed_lines.append(' ' * max(0, indent_level - indent_step) + stripped)
else:
fixed_lines.append(' ' * indent_level + stripped)
return '\n'.join(fixed_lines)
# ----------------------------
# Transformer-MoE模型定义
# ----------------------------
class PositionalEncoding(nn.Module):
"""位置编码(Transformer必备组件)"""
def __init__(self, d_model, max_len=5000):
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
"""x: (seq_len, batch_size, d_model)"""
x = x + self.pe[:x.size(0)]
return x
class TextEncoder(nn.Module):
"""自然语言编码器"""
def __init__(self, vocab_size, embed_dim, max_len):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.pos_encoder = PositionalEncoding(embed_dim, max_len)
self.fc = nn.Linear(embed_dim, embed_dim)
self.activation = nn.Tanh()
def forward(self, x):
x = self.embedding(x).permute(1, 0, 2) # (seq_len, batch_size, embed_dim)
x = self.pos_encoder(x)
x = x.permute(1, 0, 2) # (batch_size, seq_len, embed_dim)
cls_feat = x.mean(dim=1) # 取序列均值作为句子特征
return self.activation(self.fc(cls_feat))
class Expert(nn.Module):
"""Transformer解码器专家网络"""
def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ffn_dim, max_len):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.pos_encoder = PositionalEncoding(embed_dim, max_len)
decoder_layers = TransformerDecoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=ffn_dim,
batch_first=True
)
self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers=num_layers)
self.fc_out = nn.Linear(embed_dim, vocab_size)
def forward(self, tgt, memory=None, tgt_mask=None, tgt_pad_mask=None):
tgt_embed = self.embedding(tgt) # (batch_size, tgt_len, embed_dim)
tgt_embed = self.pos_encoder(tgt_embed.permute(1, 0, 2)).permute(1, 0, 2)
output = self.transformer_decoder(
tgt=tgt_embed,
memory=tgt_embed if memory is None else memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_pad_mask
)
return self.fc_out(output)
class Gating(nn.Module):
"""门控网络"""
def __init__(self, input_dim, num_experts, top_k):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Sequential(
nn.Linear(input_dim, input_dim),
nn.ReLU(),
nn.Linear(input_dim, num_experts)
)
def forward(self, x):
raw_weights = self.gate(x)
weights = F.softmax(raw_weights, dim=1)
# 稀疏激活:仅保留top-k权重
if self.top_k is not None and self.top_k < self.num_experts:
top_k_values, top_k_indices = torch.topk(weights, self.top_k, dim=1)
mask = torch.zeros_like(weights)
mask.scatter_(1, top_k_indices, 1)
weights = weights * mask
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
return weights
class MoECodeGenerator(nn.Module):
"""MoE代码生成器"""
def __init__(self, text_vocab_size, code_vocab_size, embed_dim, num_heads,
num_layers, ffn_dim, num_experts, top_k, max_text_len, max_code_len):
super().__init__()
self.text_encoder = TextEncoder(text_vocab_size, embed_dim, max_text_len)
self.experts = nn.ModuleList([
Expert(code_vocab_size, embed_dim, num_heads, num_layers, ffn_dim, max_code_len)
for _ in range(num_experts)
])
self.gating = Gating(embed_dim, num_experts, top_k)
self.code_vocab_size = code_vocab_size
self.pad_id = 0
self.eos_id = None
def set_eos_id(self, eos_id):
self.eos_id = eos_id
def generate_mask(self, size):
"""生成Transformer掩码"""
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def forward(self, text, code):
batch_size, code_len = code.size()
text_feat = self.text_encoder(text)
weights = self.gating(text_feat)
tgt_mask = self.generate_mask(code_len).to(device)
tgt_pad_mask = (code == self.pad_id).to(device)
expert_logits = []
for expert in self.experts:
logits = expert(code, tgt_mask=tgt_mask, tgt_pad_mask=tgt_pad_mask)
expert_logits.append(logits[:, :-1, :])
expert_logits = torch.stack(expert_logits, dim=1)
weights = weights.unsqueeze(-1).unsqueeze(-1)
logits = torch.sum(expert_logits * weights, dim=1)
return logits, weights.squeeze(-1).squeeze(-1)
# ----------------------------
# 损失函数与训练逻辑
# ----------------------------
def load_balancing_loss(gating_weights):
"""负载均衡损失"""
expert_mean = gating_weights.mean(dim=0)
target = torch.ones_like(expert_mean) / NUM_EXPERTS
return F.mse_loss(expert_mean, target)
def train(model, loader, optimizer, criterion, lambda_lb):
model.train()
total_loss = 0.0
for text, code in tqdm(loader, desc="训练"):
text = text.to(device)
code = code.to(device)
optimizer.zero_grad()
logits, weights = model(text, code)
target = code[:, 1:]
ce_loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
lb_loss = load_balancing_loss(weights)
loss = ce_loss + lambda_lb * lb_loss
loss.backward()
optimizer.step()
total_loss += loss.item() * text.size(0)
return total_loss / len(loader.dataset)
def evaluate(model, loader, criterion):
model.eval()
total_loss = 0.0
with torch.no_grad():
for text, code in tqdm(loader, desc="评估"):
text = text.to(device)
code = code.to(device)
logits, _ = model(text, code)
target = code[:, 1:]
loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
total_loss += loss.item() * text.size(0)
return total_loss / len(loader.dataset)
# ----------------------------
# 代码生成函数
# ----------------------------
def generate_valid_code(model, text, text_tokenizer, code_tokenizer,
max_attempts=3, max_len=256, temperature=0.7, top_p=0.9):
"""生成并验证代码"""
model.eval()
pad_id = code_tokenizer.token_to_id('<pad>')
eos_id = code_tokenizer.token_to_id('<eos>')
# 编码自然语言
text_enc = text_tokenizer.encode(text).ids
text_enc = text_enc[:MAX_TEXT_LEN - 1] + [eos_id]
text_enc += [pad_id] * (MAX_TEXT_LEN - len(text_enc))
text_tensor = torch.tensor(text_enc, dtype=torch.long).unsqueeze(0).to(device)
for attempt in range(max_attempts):
# 初始化代码序列
code_seq = [eos_id]
code_tensor = torch.tensor(code_seq, dtype=torch.long).unsqueeze(0).to(device)
with torch.no_grad():
for _ in range(max_len):
current_len = code_tensor.size(1)
tgt_mask = model.generate_mask(current_len).to(device)
tgt_pad_mask = (code_tensor == pad_id).to(device)
# 获取专家权重
text_feat = model.text_encoder(text_tensor)
weights = model.gating(text_feat)
# 专家预测
expert_logits = []
for expert in model.experts:
logits = expert(code_tensor, tgt_mask=tgt_mask, tgt_pad_mask=tgt_pad_mask)
expert_logits.append(logits[:, -1, :])
# 加权组合
expert_logits = torch.stack(expert_logits, dim=1)
weights = weights.unsqueeze(-1)
logits = torch.sum(expert_logits * weights, dim=1)
# 温度调整
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
# 核采样
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs > top_p
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = 0
sorted_probs[mask] = 0.0
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
# 采样下一个token
next_idx = torch.multinomial(sorted_probs, num_samples=1).item()
next_idx = sorted_indices[0, next_idx].item()
if next_idx == eos_id:
break
code_seq.append(next_idx)
code_tensor = torch.cat([code_tensor, torch.tensor([[next_idx]], device=device)], dim=1)
# 解码并检查语法
generated_code = code_tokenizer.decode(code_seq[1:])
improved_code = improve_code_syntax(generated_code)
is_valid, msg = is_valid_python_code(improved_code)
if is_valid:
return improved_code, attempt + 1, "语法正确"
else:
print(f"生成尝试 {attempt + 1} 失败: {msg}")
return improved_code, max_attempts, "达到最大尝试次数,返回可能存在语法错误的代码"
# ----------------------------
# 主程序
# ----------------------------
if __name__ == "__main__":
# 数据和分词器路径
data_dir = "local_data"
tokenizer_dir = "local_tokenizers"
cache_path = "local_data_cache.json"
# 确保数据目录存在
Path(data_dir).mkdir(exist_ok=True)
# 加载或生成数据(使用本地生成的数据替代下载)
if os.path.exists(cache_path):
print("加载缓存的数据...")
with open(cache_path, 'r', encoding='utf-8') as f:
all_data = json.load(f)
else:
print("生成本地示例数据...")
all_data = generate_sample_data(num_samples=500) # 生成500个样本
# 缓存数据
with open(cache_path, 'w', encoding='utf-8') as f:
json.dump(all_data, f, ensure_ascii=False)
print(f"加载了 {len(all_data)} 条代码-注释对")
# 加载或训练分词器
if os.path.exists(os.path.join(tokenizer_dir, 'text_tokenizer.json')):
print("加载已训练的分词器...")
text_tokenizer = Tokenizer.from_file(os.path.join(tokenizer_dir, 'text_tokenizer.json'))
code_tokenizer = Tokenizer.from_file(os.path.join(tokenizer_dir, 'code_tokenizer.json'))
else:
print("正在训练分词器...")
text_data = [item[0] for item in all_data]
code_data = [item[1] for item in all_data]
text_tokenizer, code_tokenizer = train_tokenizers_from_data(text_data, code_data, tokenizer_dir)
# 划分训练集和验证集
train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
# 构建数据集
train_dataset = CodeDataset(train_data, text_tokenizer, code_tokenizer)
val_dataset = CodeDataset(val_data, text_tokenizer, code_tokenizer)
code_vocab_size = code_tokenizer.get_vocab_size()
text_vocab_size = text_tokenizer.get_vocab_size()
print(f"文本词汇表大小: {text_vocab_size}, 代码词汇表大小: {code_vocab_size}")
# 数据加载器
def collate_fn(batch):
texts, codes = zip(*batch)
return torch.stack(texts), torch.stack(codes)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
collate_fn=collate_fn)
# 初始化模型
model = MoECodeGenerator(
text_vocab_size=text_vocab_size,
code_vocab_size=code_vocab_size,
embed_dim=EMBEDDING_DIM,
num_heads=NUM_HEADS,
num_layers=NUM_LAYERS,
ffn_dim=FFN_DIM,
num_experts=NUM_EXPERTS,
top_k=TOP_K,
max_text_len=MAX_TEXT_LEN,
max_code_len=MAX_CODE_LEN
).to(device)
model.set_eos_id(code_tokenizer.token_to_id('<eos>'))
# 损失函数与优化器
criterion = nn.CrossEntropyLoss(ignore_index=code_tokenizer.token_to_id('<pad>'))
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
lambda_lb = 0.1
# 训练模型
train_losses, val_losses = [], []
print("\n开始训练模型...")
for epoch in range(EPOCHS):
print(f"\nEpoch {epoch + 1}/{EPOCHS}")
train_loss = train(model, train_loader, optimizer, criterion, lambda_lb)
val_loss = evaluate(model, val_loader, criterion)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}")
# 每5个epoch保存一次模型
if (epoch + 1) % 5 == 0:
torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
# 绘制损失曲线
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label="训练损失")
plt.plot(val_losses, label="验证损失")
plt.xlabel("Epoch")
plt.ylabel("损失")
plt.title("训练与验证损失曲线")
plt.legend()
plt.savefig("loss_curve.png")
plt.show()
# 生成代码示例
print("\n===== 代码生成示例 =====")
test_queries = [
"Write a Python function to calculate the factorial of a number",
"Python function to sort a list of numbers",
"Implement a simple addition function",
"Check if a number is even or odd",
"Create a function that returns the sum of a list"
]
# 中文查询示例
chinese_queries = [
"写一个Python函数来计算两个矩阵的乘积",
"Python函数,将列表中的元素去重并保持原有顺序",
"实现一个Python装饰器,用于计算函数执行时间",
"写一个函数来验证电子邮件地址的格式是否正确",
"创建一个简单的Python类来表示栈并实现基本操作"
]
# 测试英文查询
for query in test_queries:
code, attempts, msg = generate_valid_code(
model, query, text_tokenizer, code_tokenizer,
max_attempts=3, max_len=200, temperature=0.6, top_p=0.9
)
print(f"查询: {query}")
print(f"状态: {msg} (尝试了 {attempts} 次)")
print(f"生成代码:\n{code}\n" + "-" * 80)
# 测试中文查询
for query in chinese_queries:
code, attempts, msg = generate_valid_code(
model, query, text_tokenizer, code_tokenizer,
max_attempts=3, max_len=200, temperature=0.6, top_p=0.9
)
print(f"查询: {query}")
print(f"状态: {msg} (尝试了 {attempts} 次)")
print(f"生成代码:\n{code}\n" + "-" * 80)
四、程序运行结果展示
使用设备: cpu
加载缓存的数据...
加载了 500 条代码-注释对
加载已训练的分词器...
文本词汇表大小: 177, 代码词汇表大小: 339
开始训练模型...
Epoch 1/15
训练: 100%|██████████| 57/57 [00:54<00:00, 1.04it/s]
评估: 100%|██████████| 7/7 [00:00<00:00, 8.88it/s]
训练损失: 2.7673, 验证损失: 0.9751
Epoch 2/15
训练: 100%|██████████| 57/57 [00:52<00:00, 1.09it/s]
评估: 100%|██████████| 7/7 [00:00<00:00, 9.65it/s]
训练损失: 0.6274, 验证损失: 0.3590
Epoch 3/15
训练: 100%|██████████| 57/57 [00:52<00:00, 1.08it/s]
评估: 100%|██████████| 7/7 [00:00<00:00, 9.26it/s]
训练损失: 0.3264, 验证损失: 0.2405
Epoch 4/15
训练: 100%|██████████| 57/57 [00:52<00:00, 1.09it/s]
评估: 100%|██████████| 7/7 [00:00<00:00, 9.16it/s]
训练损失: 0.2284, 验证损失: 0.2034
Epoch 5/15
训练: 100%|██████████| 57/57 [00:55<00:00, 1.03it/s]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.68it/s]
训练损失: 0.1878, 验证损失: 0.1651
Epoch 6/15
训练: 100%|██████████| 57/57 [01:07<00:00, 1.18s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 9.16it/s]
训练损失: 0.1471, 验证损失: 0.1412
Epoch 7/15
训练: 100%|██████████| 57/57 [01:05<00:00, 1.15s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.06it/s]
训练损失: 0.1239, 验证损失: 0.1285
Epoch 8/15
训练: 100%|██████████| 57/57 [01:06<00:00, 1.18s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.62it/s]
训练损失: 0.1042, 验证损失: 0.1154
Epoch 9/15
训练: 100%|██████████| 57/57 [01:07<00:00, 1.19s/it]
评估: 100%|██████████| 7/7 [00:01<00:00, 6.97it/s]
训练损失: 0.0835, 验证损失: 0.0967
Epoch 10/15
训练: 100%|██████████| 57/57 [01:10<00:00, 1.24s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.70it/s]
训练损失: 0.0670, 验证损失: 0.0854
Epoch 11/15
训练: 100%|██████████| 57/57 [01:08<00:00, 1.20s/it]
评估: 100%|██████████| 7/7 [00:01<00:00, 6.78it/s]
训练损失: 0.0582, 验证损失: 0.0839
Epoch 12/15
训练: 100%|██████████| 57/57 [01:10<00:00, 1.25s/it]
评估: 100%|██████████| 7/7 [00:01<00:00, 6.64it/s]
训练损失: 0.0529, 验证损失: 0.0848
Epoch 13/15
训练: 100%|██████████| 57/57 [01:11<00:00, 1.25s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.44it/s]
训练损失: 0.0505, 验证损失: 0.0816
Epoch 14/15
训练: 100%|██████████| 57/57 [01:13<00:00, 1.29s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.64it/s]
训练损失: 0.0424, 验证损失: 0.0809
Epoch 15/15
训练: 100%|██████████| 57/57 [01:10<00:00, 1.23s/it]
评估: 100%|██████████| 7/7 [00:00<00:00, 7.59it/s]
训练损失: 0.0467, 验证损失: 0.0972

===== 代码生成示例 =====
查询: Write a Python function to calculate the factorial of a number
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: Python function to sort a list of numbers
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: Implement a simple addition function
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: Check if a number is even or odd
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: Create a function that returns the sum of a list
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: 写一个Python函数来计算两个矩阵的乘积
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: Python函数,将列表中的元素去重并保持原有顺序
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: 实现一个Python装饰器,用于计算函数执行时间
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: 写一个函数来验证电子邮件地址的格式是否正确
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: 创建一个简单的Python类来表示栈并实现基本操作
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
五、总结
本文提出了一种基于Transformer和混合专家(MoE)结构的自然语言到Python代码生成模型。该模型通过文本编码器提取自然语言特征,门控网络选择最适合的专家网络生成代码,并利用语法校验确保输出正确性。实现流程包括:
(1)构建500条多样化训练数据;
(2)训练专用分词器处理自然语言和代码;
(3)设计MoE架构(4个专家)实现任务特定处理;
(4)结合交叉熵和负载均衡损失优化模型;
(5)实现自动语法校验和简单修复功能。
实验表明模型能生成符合语法的代码,但存在功能单一问题,需进一步优化专家网络能力和训练数据质量。

1万+

被折叠的 条评论
为什么被折叠?



