自然语言处理实战——基于混合专家模型的代码生成

部署运行你感兴趣的模型镜像

目录

一、核心目标

二、模块详解

1. 环境配置与参数设置

2. 数据处理模块

3. 分词器训练

4. 语法校验功能

5. 模型结构(MoE+Transformer)

6. 损失函数与训练逻辑

7. 代码生成函数

8. 主程序流程

三、基于混合专家模型的代码生成的Python代码完整实现

四、程序运行结果展示

五、总结


一、核心目标

通过构建一个融合 Transformer 和 MoE 结构的模型,实现 “自然语言描述→Python 代码” 的生成。模型会根据输入的文本描述,自动选择最适合的 “专家网络” 生成代码,并通过语法校验确保输出代码的有效性。

二、模块详解

1. 环境配置与参数设置
  • 库导入:导入 PyTorch(深度学习框架)、tokenizers(分词工具)、AST(语法分析)等核心库,用于模型构建、数据处理和语法校验。
  • 设备配置:自动检测并使用 GPU(若可用)或 CPU 进行计算,提升训练效率。
  • 超参数定义:设置模型结构和训练相关的关键参数,例如:
    • 序列长度(MAX_TEXT_LEN/MAX_CODE_LEN):限制输入文本和生成代码的最大长度。
    • 模型维度(EMBEDDING_DIM/FFN_DIM):控制嵌入层和前馈网络的维度,影响模型表达能力。
    • 专家网络配置(NUM_EXPERTS/TOP_K):设置专家数量和每次激活的专家数(MoE 核心参数)。
    • 训练参数(BATCH_SIZE/EPOCHS):控制训练批次大小和迭代轮数。
2. 数据处理模块

负责生成、加载和预处理训练数据(自然语言 - 代码对),为模型提供输入。

  • CodeDataset:自定义数据集,实现对 “自然语言描述 - 代码” 对的处理:

    • 编码:使用分词器将文本和代码转换为整数序列(token ID)。
    • 截断与填充:确保序列长度统一(符合MAX_TEXT_LEN/MAX_CODE_LEN),短序列补零(padding),长序列截断。
    • 特殊标记:添加<eos>(结束标记)和<pad>(填充标记),帮助模型识别序列边界。
  • generate_sample_data函数:生成示例训练数据(因无法下载外部数据集):

    • 基于 15 个基础代码样本(如 “计算阶乘”“检查回文”)扩展出 500 个多样化样本。
    • 通过随机修改函数名(如addcalculate)或变量名(如ax)增加数据多样性,模拟真实代码场景。
  • 数据缓存:生成的数据保存为 JSON 文件,避免重复生成,提升二次运行效率。

3. 分词器训练

将自然语言和代码转换为模型可理解的 token 序列,是文本输入模型的前提。

  • train_tokenizers_from_data函数:使用 BPE(字节对编码)算法训练两个分词器:
    • 文本分词器:处理自然语言描述(如 “计算两个数的和”),生成适合文本的词汇表(大小 3000)。
    • 代码分词器:处理 Python 代码(如def add(a, b): return a + b),生成适合代码的词汇表(大小 5000)。
  • 分词器保存与加载:训练好的分词器保存到本地,下次运行可直接加载,节省训练时间。
4. 语法校验功能

确保生成的代码符合 Python 语法规则,提升输出质量。

  • is_valid_python_code函数:使用 Python 内置的ast模块解析代码,检查语法是否正确:
    • 若解析成功,返回 “语法正确”;若失败(如缩进错误、括号不匹配),返回具体错误信息。
  • improve_code_syntax函数:简单修复常见语法问题(主要是缩进错误):
    • 根据代码结构(如:结尾的行需要缩进)自动调整缩进层级,提升代码语法正确性。
5. 模型结构(MoE+Transformer)

核心模块,实现 “自然语言→代码” 的映射,由文本编码器、专家网络、门控网络三部分组成。

  • PositionalEncoding:为序列添加位置信息(Transformer 必备组件):

    • 通过正弦 / 余弦函数生成位置编码,使模型感知 token 在序列中的位置(如 “a+b” 中 “a” 在 “+” 之前)。
  • TextEncoder:自然语言编码器:

    • 将文本 token 序列转换为嵌入向量,结合位置编码后,通过均值池化得到 “句子级特征向量”,用于后续门控网络选择专家。
  • Expert:专家网络(每个专家都是一个 Transformer 解码器):

    • 每个专家负责处理特定类型的代码生成任务(如数学运算、字符串处理)。
    • 输入代码序列,通过 Transformer 解码器输出下一个 token 的预测分布(用于生成代码)。
  • Gating:门控网络(MoE 的核心):

    • 接收文本编码器的特征向量,计算每个专家的权重(概率)。
    • 仅激活权重最高的TOP_K个专家(稀疏激活),减少计算量的同时提升模型专注度。
  • MoECodeGenerator:整合上述组件的完整模型:

    • 输入文本和代码序列,通过文本编码器生成特征→门控网络选择专家→专家输出预测→加权融合专家结果,最终输出代码 token 的预测分布。
6. 损失函数与训练逻辑

负责模型参数优化,确保模型能学习到 “自然语言→代码” 的映射规律。

  • 损失函数

    • 交叉熵损失(ce_loss):衡量模型预测的代码 token 与真实 token 的差异,确保生成准确性。
    • 负载均衡损失(lb_loss):通过 MSE 损失使各专家的平均权重接近均匀分布,避免部分专家被过度使用而其他专家闲置。
    • 总损失:ce_loss + 0.1 * lb_loss(兼顾准确性和专家均衡性)。
  • train函数:模型训练过程:

    • 批量输入文本和代码,前向传播计算预测结果和损失。
    • 反向传播更新模型参数,最小化损失。
  • evaluate函数:模型验证过程:

    • 在验证集上计算损失(不更新参数),监控模型是否过拟合(如验证损失上升但训练损失下降)。
7. 代码生成函数

基于训练好的模型,根据自然语言描述生成代码,并确保语法正确。

  • generate_valid_code函数
    • 编码输入文本:将自然语言描述转换为模型可处理的 token 序列。
    • 自回归生成:从<eos>开始,逐步预测下一个 token,直到生成<eos>或达到最大长度。
    • 采样策略:使用核采样(top-p)和温度调整控制生成多样性(温度越低,生成越保守;top-p 越小,过滤低概率 token 越严格)。
    • 多轮尝试与校验:最多尝试 3 次生成,若某次生成的代码通过 AST 语法校验,则返回该代码;否则返回最后一次生成的结果。
8. 主程序流程

串联上述所有模块,实现从数据准备到模型训练再到代码生成的完整流程:

  1. 数据准备:加载缓存数据或生成新数据(500 条 “文本 - 代码” 对)。
  2. 分词器准备:加载已训练的分词器或基于当前数据训练新分词器。
  3. 数据集划分:将数据按 9:1 划分为训练集(用于模型学习)和验证集(用于监控性能)。
  4. 模型初始化:创建 MoE 模型并设置关键参数(如词汇表大小、专家数量)。
  5. 模型训练:执行 15 轮训练,每 5 轮保存一次模型,绘制训练 / 验证损失曲线(直观展示模型收敛情况)。
  6. 代码生成示例:使用训练好的模型对中英文查询(如 “计算阶乘”“Check if a number is even”)生成代码,并输出结果及语法校验状态。

三、基于混合专家模型的代码生成的Python代码完整实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import ast
import json
import zipfile
from pathlib import Path
from sklearn.model_selection import train_test_split

# 配置中文字体
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# ----------------------------
# 超参数设置
# ----------------------------
MAX_TEXT_LEN = 128
MAX_CODE_LEN = 256
EMBEDDING_DIM = 256  # 减小维度以便在CPU上运行
FFN_DIM = 512
NUM_HEADS = 4
NUM_LAYERS = 2
NUM_EXPERTS = 4
TOP_K = 2
BATCH_SIZE = 8  # 减小批次大小
EPOCHS = 15  # 减少训练轮数
LEARNING_RATE = 3e-4


# ----------------------------
# 数据处理与分词器
# ----------------------------
class CodeDataset(Dataset):
    """自然语言-代码对数据集"""

    def __init__(self, data, text_tokenizer, code_tokenizer):
        self.data = data
        self.text_tokenizer = text_tokenizer
        self.code_tokenizer = code_tokenizer
        self.pad_id = code_tokenizer.token_to_id('<pad>')
        self.eos_id = code_tokenizer.token_to_id('<eos>')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, code = self.data[idx]

        # 自然语言编码
        text_enc = self.text_tokenizer.encode(text).ids
        text_enc = text_enc[:MAX_TEXT_LEN - 1] + [self.eos_id]
        text_enc += [self.pad_id] * (MAX_TEXT_LEN - len(text_enc))

        # 代码编码
        code_enc = self.code_tokenizer.encode(code).ids
        code_enc = [self.eos_id] + code_enc[:MAX_CODE_LEN - 2] + [self.eos_id]
        code_enc += [self.pad_id] * (MAX_CODE_LEN - len(code_enc))

        return (
            torch.tensor(text_enc, dtype=torch.long),
            torch.tensor(code_enc, dtype=torch.long)
        )


def generate_sample_data(num_samples=500):
    """生成示例数据"""
    base_samples = [
        "计算两个数的和\tdef add(a, b):\n    return a + b",
        "计算斐波那契数列第n项\tdef fibonacci(n):\n    if n <= 0:\n        return 0\n    elif n == 1:\n        return 1\n    else:\n        return fibonacci(n-1) + fibonacci(n-2)",
        "判断一个数是否为质数\tdef is_prime(n):\n    if n <= 1:\n        return False\n    for i in range(2, int(n**0.5)+1):\n        if n % i == 0:\n            return False\n    return True",
        "将列表中的元素去重\tdef unique_list(lst):\n    return list(set(lst))",
        "计算列表的平均值\tdef average(lst):\n    return sum(lst) / len(lst) if lst else 0",
        "将摄氏度转换为华氏度\tdef celsius_to_fahrenheit(c):\n    return c * 9/5 + 32",
        "统计字符串中单词的数量\tdef word_count(s):\n    return len(s.split())",
        "求两个数的最大公约数\tdef gcd(a, b):\n    while b:\n        a, b = b, a % b\n    return a",
        "生成指定范围的随机整数\timport random\ndef random_int(min_val, max_val):\n    return random.randint(min_val, max_val)",
        "检查字符串是否为回文\tdef is_palindrome(s):\n    s = s.lower().replace(' ', '')\n    return s == s[::-1]",
        "将列表元素从小到大排序\tdef sort_list(lst):\n    return sorted(lst)",
        "计算阶乘\tdef factorial(n):\n    if n == 0:\n        return 1\n    else:\n        return n * factorial(n-1)",
        "求列表中所有元素的和\tdef sum_list(lst):\n    return sum(lst)",
        "将字典的值转换为列表\tdef dict_values_to_list(d):\n    return list(d.values())",
        "检查列表是否为空\tdef is_empty(lst):\n    return len(lst) == 0"
    ]

    # 样本
    samples = base_samples.copy()
    while len(samples) < num_samples:
        # 随机选择一个基础样本并稍作修改
        import random
        base = random.choice(base_samples)
        text, code = base.split('\t', 1)

        # 随机修改函数名或变量名
        func_names = ["calculate", "compute", "get", "find", "determine"]
        var_names = ["x", "y", "val", "num", "a", "b", "i", "j"]

        if random.random() < 0.3:
            # 修改函数名
            old_func = code.split('def ')[1].split('(')[0]
            new_func = random.choice(func_names)
            code = code.replace(f'def {old_func}', f'def {new_func}')

        if random.random() < 0.2:
            # 修改变量名
            old_var = random.choice(var_names)
            new_var = random.choice([v for v in var_names if v != old_var])
            code = code.replace(old_var, new_var)

        samples.append(f"{text}\t{code}")

    return [s.split('\t', 1) for s in samples]


def train_tokenizers_from_data(text_data, code_data, save_dir):
    """从数据训练BPE分词器"""
    Path(save_dir).mkdir(exist_ok=True)

    # 1. 文本分词器(自然语言)
    text_tokenizer = Tokenizer(BPE(unk_token='<unk>'))
    text_tokenizer.pre_tokenizer = Whitespace()
    text_trainer = BpeTrainer(
        special_tokens=['<pad>', '<unk>', '<eos>'],
        vocab_size=3000
    )
    text_tokenizer.train_from_iterator(text_data, text_trainer)
    text_tokenizer.save(os.path.join(save_dir, 'text_tokenizer.json'))

    # 2. 代码分词器
    code_tokenizer = Tokenizer(BPE(unk_token='<unk>'))
    code_tokenizer.pre_tokenizer = Whitespace()
    code_trainer = BpeTrainer(
        special_tokens=['<pad>', '<unk>', '<eos>'],
        vocab_size=5000
    )
    code_tokenizer.train_from_iterator(code_data, code_trainer)
    code_tokenizer.save(os.path.join(save_dir, 'code_tokenizer.json'))

    return text_tokenizer, code_tokenizer


# ----------------------------
# AST语法校验功能
# ----------------------------
def is_valid_python_code(code):
    """使用AST模块检查Python代码的语法有效性"""
    try:
        ast.parse(code)
        return True, "语法正确"
    except SyntaxError as e:
        return False, f"语法错误: {str(e)}"
    except Exception as e:
        return False, f"解析错误: {str(e)}"


def improve_code_syntax(code):
    """简单修复常见的代码语法问题"""
    lines = code.split('\n')
    fixed_lines = []
    indent_level = 0
    indent_step = 4  # 使用4个空格缩进

    for line in lines:
        stripped = line.strip()

        # 减少缩进的情况
        if stripped.endswith(':') and not stripped.startswith(('elif', 'else')):
            fixed_lines.append(' ' * indent_level + stripped)
            indent_level += indent_step
        elif stripped.startswith(('return', 'break', 'continue')):
            fixed_lines.append(' ' * max(0, indent_level - indent_step) + stripped)
        else:
            fixed_lines.append(' ' * indent_level + stripped)

    return '\n'.join(fixed_lines)


# ----------------------------
# Transformer-MoE模型定义
# ----------------------------
class PositionalEncoding(nn.Module):
    """位置编码(Transformer必备组件)"""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (seq_len, batch_size, d_model)"""
        x = x + self.pe[:x.size(0)]
        return x


class TextEncoder(nn.Module):
    """自然语言编码器"""

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.activation = nn.Tanh()

    def forward(self, x):
        x = self.embedding(x).permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        x = self.pos_encoder(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, embed_dim)
        cls_feat = x.mean(dim=1)  # 取序列均值作为句子特征
        return self.activation(self.fc(cls_feat))


class Expert(nn.Module):
    """Transformer解码器专家网络"""

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ffn_dim, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len)
        decoder_layers = TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            batch_first=True
        )
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, tgt, memory=None, tgt_mask=None, tgt_pad_mask=None):
        tgt_embed = self.embedding(tgt)  # (batch_size, tgt_len, embed_dim)
        tgt_embed = self.pos_encoder(tgt_embed.permute(1, 0, 2)).permute(1, 0, 2)

        output = self.transformer_decoder(
            tgt=tgt_embed,
            memory=tgt_embed if memory is None else memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_pad_mask
        )
        return self.fc_out(output)


class Gating(nn.Module):
    """门控网络"""

    def __init__(self, input_dim, num_experts, top_k):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, num_experts)
        )

    def forward(self, x):
        raw_weights = self.gate(x)
        weights = F.softmax(raw_weights, dim=1)

        # 稀疏激活:仅保留top-k权重
        if self.top_k is not None and self.top_k < self.num_experts:
            top_k_values, top_k_indices = torch.topk(weights, self.top_k, dim=1)
            mask = torch.zeros_like(weights)
            mask.scatter_(1, top_k_indices, 1)
            weights = weights * mask
            weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        return weights


class MoECodeGenerator(nn.Module):
    """MoE代码生成器"""

    def __init__(self, text_vocab_size, code_vocab_size, embed_dim, num_heads,
                 num_layers, ffn_dim, num_experts, top_k, max_text_len, max_code_len):
        super().__init__()
        self.text_encoder = TextEncoder(text_vocab_size, embed_dim, max_text_len)
        self.experts = nn.ModuleList([
            Expert(code_vocab_size, embed_dim, num_heads, num_layers, ffn_dim, max_code_len)
            for _ in range(num_experts)
        ])
        self.gating = Gating(embed_dim, num_experts, top_k)
        self.code_vocab_size = code_vocab_size
        self.pad_id = 0
        self.eos_id = None

    def set_eos_id(self, eos_id):
        self.eos_id = eos_id

    def generate_mask(self, size):
        """生成Transformer掩码"""
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, text, code):
        batch_size, code_len = code.size()
        text_feat = self.text_encoder(text)
        weights = self.gating(text_feat)

        tgt_mask = self.generate_mask(code_len).to(device)
        tgt_pad_mask = (code == self.pad_id).to(device)

        expert_logits = []
        for expert in self.experts:
            logits = expert(code, tgt_mask=tgt_mask, tgt_pad_mask=tgt_pad_mask)
            expert_logits.append(logits[:, :-1, :])

        expert_logits = torch.stack(expert_logits, dim=1)
        weights = weights.unsqueeze(-1).unsqueeze(-1)
        logits = torch.sum(expert_logits * weights, dim=1)

        return logits, weights.squeeze(-1).squeeze(-1)


# ----------------------------
# 损失函数与训练逻辑
# ----------------------------
def load_balancing_loss(gating_weights):
    """负载均衡损失"""
    expert_mean = gating_weights.mean(dim=0)
    target = torch.ones_like(expert_mean) / NUM_EXPERTS
    return F.mse_loss(expert_mean, target)


def train(model, loader, optimizer, criterion, lambda_lb):
    model.train()
    total_loss = 0.0
    for text, code in tqdm(loader, desc="训练"):
        text = text.to(device)
        code = code.to(device)
        optimizer.zero_grad()

        logits, weights = model(text, code)
        target = code[:, 1:]

        ce_loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        lb_loss = load_balancing_loss(weights)
        loss = ce_loss + lambda_lb * lb_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * text.size(0)

    return total_loss / len(loader.dataset)


def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for text, code in tqdm(loader, desc="评估"):
            text = text.to(device)
            code = code.to(device)
            logits, _ = model(text, code)
            target = code[:, 1:]
            loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
            total_loss += loss.item() * text.size(0)

    return total_loss / len(loader.dataset)


# ----------------------------
# 代码生成函数
# ----------------------------
def generate_valid_code(model, text, text_tokenizer, code_tokenizer,
                        max_attempts=3, max_len=256, temperature=0.7, top_p=0.9):
    """生成并验证代码"""
    model.eval()
    pad_id = code_tokenizer.token_to_id('<pad>')
    eos_id = code_tokenizer.token_to_id('<eos>')

    # 编码自然语言
    text_enc = text_tokenizer.encode(text).ids
    text_enc = text_enc[:MAX_TEXT_LEN - 1] + [eos_id]
    text_enc += [pad_id] * (MAX_TEXT_LEN - len(text_enc))
    text_tensor = torch.tensor(text_enc, dtype=torch.long).unsqueeze(0).to(device)

    for attempt in range(max_attempts):
        # 初始化代码序列
        code_seq = [eos_id]
        code_tensor = torch.tensor(code_seq, dtype=torch.long).unsqueeze(0).to(device)

        with torch.no_grad():
            for _ in range(max_len):
                current_len = code_tensor.size(1)
                tgt_mask = model.generate_mask(current_len).to(device)
                tgt_pad_mask = (code_tensor == pad_id).to(device)

                # 获取专家权重
                text_feat = model.text_encoder(text_tensor)
                weights = model.gating(text_feat)

                # 专家预测
                expert_logits = []
                for expert in model.experts:
                    logits = expert(code_tensor, tgt_mask=tgt_mask, tgt_pad_mask=tgt_pad_mask)
                    expert_logits.append(logits[:, -1, :])

                # 加权组合
                expert_logits = torch.stack(expert_logits, dim=1)
                weights = weights.unsqueeze(-1)
                logits = torch.sum(expert_logits * weights, dim=1)

                # 温度调整
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)

                # 核采样
                sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                mask = cumulative_probs > top_p
                mask[..., 1:] = mask[..., :-1].clone()
                mask[..., 0] = 0
                sorted_probs[mask] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

                # 采样下一个token
                next_idx = torch.multinomial(sorted_probs, num_samples=1).item()
                next_idx = sorted_indices[0, next_idx].item()

                if next_idx == eos_id:
                    break
                code_seq.append(next_idx)
                code_tensor = torch.cat([code_tensor, torch.tensor([[next_idx]], device=device)], dim=1)

        # 解码并检查语法
        generated_code = code_tokenizer.decode(code_seq[1:])
        improved_code = improve_code_syntax(generated_code)
        is_valid, msg = is_valid_python_code(improved_code)

        if is_valid:
            return improved_code, attempt + 1, "语法正确"
        else:
            print(f"生成尝试 {attempt + 1} 失败: {msg}")

    return improved_code, max_attempts, "达到最大尝试次数,返回可能存在语法错误的代码"


# ----------------------------
# 主程序
# ----------------------------
if __name__ == "__main__":
    # 数据和分词器路径
    data_dir = "local_data"
    tokenizer_dir = "local_tokenizers"
    cache_path = "local_data_cache.json"

    # 确保数据目录存在
    Path(data_dir).mkdir(exist_ok=True)

    # 加载或生成数据(使用本地生成的数据替代下载)
    if os.path.exists(cache_path):
        print("加载缓存的数据...")
        with open(cache_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
    else:
        print("生成本地示例数据...")
        all_data = generate_sample_data(num_samples=500)  # 生成500个样本

        # 缓存数据
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(all_data, f, ensure_ascii=False)

    print(f"加载了 {len(all_data)} 条代码-注释对")

    # 加载或训练分词器
    if os.path.exists(os.path.join(tokenizer_dir, 'text_tokenizer.json')):
        print("加载已训练的分词器...")
        text_tokenizer = Tokenizer.from_file(os.path.join(tokenizer_dir, 'text_tokenizer.json'))
        code_tokenizer = Tokenizer.from_file(os.path.join(tokenizer_dir, 'code_tokenizer.json'))
    else:
        print("正在训练分词器...")
        text_data = [item[0] for item in all_data]
        code_data = [item[1] for item in all_data]
        text_tokenizer, code_tokenizer = train_tokenizers_from_data(text_data, code_data, tokenizer_dir)

    # 划分训练集和验证集
    train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)

    # 构建数据集
    train_dataset = CodeDataset(train_data, text_tokenizer, code_tokenizer)
    val_dataset = CodeDataset(val_data, text_tokenizer, code_tokenizer)

    code_vocab_size = code_tokenizer.get_vocab_size()
    text_vocab_size = text_tokenizer.get_vocab_size()
    print(f"文本词汇表大小: {text_vocab_size}, 代码词汇表大小: {code_vocab_size}")


    # 数据加载器
    def collate_fn(batch):
        texts, codes = zip(*batch)
        return torch.stack(texts), torch.stack(codes)


    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            collate_fn=collate_fn)

    # 初始化模型
    model = MoECodeGenerator(
        text_vocab_size=text_vocab_size,
        code_vocab_size=code_vocab_size,
        embed_dim=EMBEDDING_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        ffn_dim=FFN_DIM,
        num_experts=NUM_EXPERTS,
        top_k=TOP_K,
        max_text_len=MAX_TEXT_LEN,
        max_code_len=MAX_CODE_LEN
    ).to(device)
    model.set_eos_id(code_tokenizer.token_to_id('<eos>'))

    # 损失函数与优化器
    criterion = nn.CrossEntropyLoss(ignore_index=code_tokenizer.token_to_id('<pad>'))
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    lambda_lb = 0.1

    # 训练模型
    train_losses, val_losses = [], []
    print("\n开始训练模型...")
    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")
        train_loss = train(model, train_loader, optimizer, criterion, lambda_lb)
        val_loss = evaluate(model, val_loader, criterion)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}")

        # 每5个epoch保存一次模型
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")

    # 绘制损失曲线
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label="训练损失")
    plt.plot(val_losses, label="验证损失")
    plt.xlabel("Epoch")
    plt.ylabel("损失")
    plt.title("训练与验证损失曲线")
    plt.legend()
    plt.savefig("loss_curve.png")
    plt.show()

    # 生成代码示例
    print("\n===== 代码生成示例 =====")
    test_queries = [
        "Write a Python function to calculate the factorial of a number",
        "Python function to sort a list of numbers",
        "Implement a simple addition function",
        "Check if a number is even or odd",
        "Create a function that returns the sum of a list"
    ]

    # 中文查询示例
    chinese_queries = [
        "写一个Python函数来计算两个矩阵的乘积",
        "Python函数,将列表中的元素去重并保持原有顺序",
        "实现一个Python装饰器,用于计算函数执行时间",
        "写一个函数来验证电子邮件地址的格式是否正确",
        "创建一个简单的Python类来表示栈并实现基本操作"
    ]

    # 测试英文查询
    for query in test_queries:
        code, attempts, msg = generate_valid_code(
            model, query, text_tokenizer, code_tokenizer,
            max_attempts=3, max_len=200, temperature=0.6, top_p=0.9
        )
        print(f"查询: {query}")
        print(f"状态: {msg} (尝试了 {attempts} 次)")
        print(f"生成代码:\n{code}\n" + "-" * 80)

    # 测试中文查询
    for query in chinese_queries:
        code, attempts, msg = generate_valid_code(
            model, query, text_tokenizer, code_tokenizer,
            max_attempts=3, max_len=200, temperature=0.6, top_p=0.9
        )
        print(f"查询: {query}")
        print(f"状态: {msg} (尝试了 {attempts} 次)")
        print(f"生成代码:\n{code}\n" + "-" * 80)

四、程序运行结果展示

使用设备: cpu
加载缓存的数据...
加载了 500 条代码-注释对
加载已训练的分词器...
文本词汇表大小: 177, 代码词汇表大小: 339

开始训练模型...

Epoch 1/15

训练: 100%|██████████| 57/57 [00:54<00:00,  1.04it/s]
评估: 100%|██████████| 7/7 [00:00<00:00,  8.88it/s]
训练损失: 2.7673, 验证损失: 0.9751

Epoch 2/15
训练: 100%|██████████| 57/57 [00:52<00:00,  1.09it/s]
评估: 100%|██████████| 7/7 [00:00<00:00,  9.65it/s]
训练损失: 0.6274, 验证损失: 0.3590

Epoch 3/15
训练: 100%|██████████| 57/57 [00:52<00:00,  1.08it/s]
评估: 100%|██████████| 7/7 [00:00<00:00,  9.26it/s]
训练损失: 0.3264, 验证损失: 0.2405

Epoch 4/15
训练: 100%|██████████| 57/57 [00:52<00:00,  1.09it/s]
评估: 100%|██████████| 7/7 [00:00<00:00,  9.16it/s]
训练损失: 0.2284, 验证损失: 0.2034

Epoch 5/15
训练: 100%|██████████| 57/57 [00:55<00:00,  1.03it/s]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.68it/s]
训练损失: 0.1878, 验证损失: 0.1651

Epoch 6/15
训练: 100%|██████████| 57/57 [01:07<00:00,  1.18s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  9.16it/s]
训练损失: 0.1471, 验证损失: 0.1412

Epoch 7/15
训练: 100%|██████████| 57/57 [01:05<00:00,  1.15s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.06it/s]
训练损失: 0.1239, 验证损失: 0.1285

Epoch 8/15
训练: 100%|██████████| 57/57 [01:06<00:00,  1.18s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.62it/s]
训练损失: 0.1042, 验证损失: 0.1154

Epoch 9/15
训练: 100%|██████████| 57/57 [01:07<00:00,  1.19s/it]
评估: 100%|██████████| 7/7 [00:01<00:00,  6.97it/s]
训练损失: 0.0835, 验证损失: 0.0967

Epoch 10/15
训练: 100%|██████████| 57/57 [01:10<00:00,  1.24s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.70it/s]
训练损失: 0.0670, 验证损失: 0.0854

Epoch 11/15
训练: 100%|██████████| 57/57 [01:08<00:00,  1.20s/it]
评估: 100%|██████████| 7/7 [00:01<00:00,  6.78it/s]
训练损失: 0.0582, 验证损失: 0.0839

Epoch 12/15
训练: 100%|██████████| 57/57 [01:10<00:00,  1.25s/it]
评估: 100%|██████████| 7/7 [00:01<00:00,  6.64it/s]
训练损失: 0.0529, 验证损失: 0.0848

Epoch 13/15
训练: 100%|██████████| 57/57 [01:11<00:00,  1.25s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.44it/s]
训练损失: 0.0505, 验证损失: 0.0816

Epoch 14/15
训练: 100%|██████████| 57/57 [01:13<00:00,  1.29s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.64it/s]
训练损失: 0.0424, 验证损失: 0.0809

Epoch 15/15
训练: 100%|██████████| 57/57 [01:10<00:00,  1.23s/it]
评估: 100%|██████████| 7/7 [00:00<00:00,  7.59it/s]
训练损失: 0.0467, 验证损失: 0.0972

===== 代码生成示例 =====
查询: Write a Python function to calculate the factorial of a number
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: Python function to sort a list of numbers
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: Implement a simple addition function
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: Check if a number is even or odd
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: Create a function that returns the sum of a list
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: 写一个Python函数来计算两个矩阵的乘积
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: Python函数,将列表中的元素去重并保持原有顺序
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a + x
--------------------------------------------------------------------------------
查询: 实现一个Python装饰器,用于计算函数执行时间
状态: 语法正确 (尝试了 1 次)
生成代码:
def add ( a , x ): return a
--------------------------------------------------------------------------------
查询: 写一个函数来验证电子邮件地址的格式是否正确
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
查询: 创建一个简单的Python类来表示栈并实现基本操作
状态: 语法正确 (尝试了 1 次)
生成代码:
def dict_values_to_list ( d ): return list ( d . values ())
--------------------------------------------------------------------------------
 

五、总结

本文提出了一种基于Transformer和混合专家(MoE)结构的自然语言到Python代码生成模型。该模型通过文本编码器提取自然语言特征,门控网络选择最适合的专家网络生成代码,并利用语法校验确保输出正确性。实现流程包括:

(1)构建500条多样化训练数据;

(2)训练专用分词器处理自然语言和代码;

(3)设计MoE架构(4个专家)实现任务特定处理;

(4)结合交叉熵和负载均衡损失优化模型;

(5)实现自动语法校验和简单修复功能。

实验表明模型能生成符合语法的代码,但存在功能单一问题,需进一步优化专家网络能力和训练数据质量。

您可能感兴趣的与本文相关的镜像

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值