Training Your Own Code Model with LlamaFactory

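LlamaFactory is an open-source project for training a wide range of language models. It ships with a Gradio UI for visual training and inference. This article uses LlamaFactory to fine-tune CodeGemma so that the model can generate code against a private API. Day-to-day development often depends on an internal framework or on internal class and component libraries, and open-source models either never saw that code during pretraining or saw only part of it. Suppose we have a stock code library with a function that fetches a stock's historical prices. We start by having a model generate code that fetches a stock price, assuming the ticker object in jqdata can return a stock's historical data.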

from jqdata import ticker

def get_stock_price(ticker_symbol):
    """
    This method retrieves the current stock price for a given ticker symbol.
    
    Parameters:
    - ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple, 'TSLA' for Tesla).
    
    Returns:
    - float: The current stock price, or None if the lookup fails.
    """
    try:
        # Download stock data for the given ticker
        # (the parameter is named ticker_symbol so it does not shadow the imported ticker)
        stock = ticker(ticker_symbol)
        
        # Get the latest market data
        stock_data = stock.history(period="1d")
        
        # Extract and return the closing price for the day
        current_price = stock_data['Close'][0]
        return current_price
    except Exception as e:
        print(f"Error retrieving stock price for {ticker_symbol}: {e}")
        return None
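
Because jqdata.ticker stands for a private API here, the snippet above cannot run without it. The sketch below swaps in a hypothetical stand-in, FakeTicker, that returns canned closing prices, just to exercise the call pattern the model is expected to learn. FakeTicker, its made-up prices, and the ticker_factory parameter are assumptions for illustration only and are not part of any real library.

```python
class FakeTicker:
    """Hypothetical stand-in for the private `jqdata.ticker` object."""

    # Canned closing prices, oldest first (made-up numbers).
    _CLOSES = {"AAPL": [189.0, 190.5, 191.2, 188.7, 192.3]}

    def __init__(self, symbol):
        self.symbol = symbol

    def history(self, period="1d"):
        """Return the last N closes for a period such as '1d' or '5d'."""
        days = int(period.rstrip("d"))
        closes = self._CLOSES[self.symbol][-days:]
        return {"Close": closes}


def get_stock_price(ticker_symbol, ticker_factory=FakeTicker):
    """Return the latest close, or None if the lookup fails."""
    try:
        stock = ticker_factory(ticker_symbol)
        stock_data = stock.history(period="1d")
        return stock_data["Close"][0]
    except Exception as e:
        print(f"Error retrieving stock price for {ticker_symbol}: {e}")
        return None


print(get_stock_price("AAPL"))  # latest canned close
print(get_stock_price("TSLA"))  # unknown symbol, so the error path returns None
```

Passing the factory as a parameter keeps the function body identical to the generated code while letting it run without the private library.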

Our goal is for CodeGemma to learn this library's code for fetching historical stock prices and to use that method at inference time.
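The workflow has three parts: training, evaluation, and inference. The training set in this article is tiny, only a few records, and evaluation uses a split of 0.1. The aim here is just to run the whole pipeline end to end; in production, prepare more data for evaluation.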
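
To see why a 0.1 split is only a smoke test at this scale, it helps to count the rows. The sketch below assumes the held-out size is rounded up, as scikit-learn-style splitters do; the exact rounding used inside LlamaFactory may differ, and split_sizes is a hypothetical helper.

```python
import math


def split_sizes(n_samples, val_size):
    """Return (train_rows, eval_rows) for a fractional validation split."""
    n_val = math.ceil(n_samples * val_size)  # assumption: round the eval share up
    return n_samples - n_val, n_val


for n in (5, 50, 500):
    train_rows, eval_rows = split_sizes(n, 0.1)
    print(f"{n} samples: {train_rows} train / {eval_rows} eval")
```

Under that assumption, five records leave a single example for evaluation, so the evaluation loss says very little about generalization.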

Preparing the Dataset

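LlamaFactory can load datasets from three sources: HuggingFace, ModelScope, or local files. For convenience, this article uses a local file. The supported file formats are alpaca and sharegpt, shown below: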
sharegpt

{
  "conversation": [
    {
      "role": "user",
      "content": "Can you help me write a short story about a dragon and a knight?"
    },
    {
      "role": "assistant",
      "content": "Certainly! Here's a short story:\n\nOnce upon a time in a kingdom far away, a brave knight set out on a quest to find a fearsome dragon..."
    },
    {
      "role": "user",
      "content": "That's great! Can you make the dragon friendly instead of fearsome?"
    },
    {
      "role": "assistant",
      "content": "Of course! Let me revise the story:\n\nOnce upon a time in a kingdom far away, a brave knight set out on a quest and discovered a friendly dragon who..."
    }
  ]
}

alpaca

{
  "instruction": "Explain the concept of gravitational waves.",
  "input": "",
  "output": "Gravitational waves are ripples in the fabric of spacetime caused by some of the most violent and energetic processes in the universe. They are produced when massive bodies accelerate, such as during the collision of black holes or neutron stars. Predicted by Albert Einstein's general theory of relativity, these waves propagate outward from their source at the speed of light, carrying information about their origins and about the nature of gravity itself."
}
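
The sharegpt sample above stores turns as role/content pairs under a conversation key. As far as we know, LlamaFactory's default sharegpt layout expects a conversations list with from/value keys, so a file shaped like the sample would need its column and tags mapped in dataset_info.json. A sketch of such an entry, built in Python; the dataset name chat_demo and the file name chat_demo.json are hypothetical.

```python
import json

# Hypothetical dataset_info.json entry for a role/content style chat file.
entry = {
    "chat_demo": {
        "file_name": "chat_demo.json",
        "formatting": "sharegpt",
        # Map LlamaFactory's "messages" column to the key used in the file.
        "columns": {"messages": "conversation"},
        # Tell the loader which keys and values mark each speaker.
        "tags": {
            "role_tag": "role",
            "content_tag": "content",
            "user_tag": "user",
            "assistant_tag": "assistant",
        },
    }
}

print(json.dumps(entry, indent=2))
```

The alpaca sample needs no such mapping, which is one reason this article uses alpaca for its own data.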

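LlamaFactory converts ShareGPT and Alpaca datasets into its own internal format and finally turns them into training data through the tokenizer's chat template. Create the training file and register it in dataset_info.json.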

  • Add the following entry to dataset_info.json
  "pydata": {
    "file_name": "pydata.json"
  },
  • Prepare the data file pydata.json
[
  {
    "instruction": "Write a Python function to get the current stock price for a given ticker symbol using the `jqdata` library.",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\ndef get_stock_price(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple, 'TSLA' for Tesla).\n\n    Returns:\n    - float: The current stock price.\n    \"\"\"\n    try:\n        # Download stock data for the given ticker\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data\n        stock_data = stock.history(period=\"1d\")\n\n        # Extract and return the closing price for the day\n        current_price = stock_data['Close'][0]\n        return current_price\n    except Exception as e:\n        print(f\"Error retrieving stock price for {ticker_symbol}: {e}\")\n        return None\n\n"
  },
  {
    "instruction": "How can I calculate the 5-day moving average (MA5) of a stock using Python and the `jqdata` library?",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\ndef get_ma5(ticker_symbol):\n    \"\"\"\n    This method calculates the 5-day moving average (MA5) for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - float: The 5-day moving average of the closing prices.\n    \"\"\"\n    try:\n        # Download stock data for the past 5 days\n        stock = ticker(ticker_symbol)\n        stock_data = stock.history(period=\"5d\")\n\n        # Calculate the MA5\n        ma5 = stock_data['Close'].mean()\n        return ma5\n    except Exception as e:\n        print(f\"Error calculating MA5 for {ticker_symbol}: {e}\")\n        return None\n"
  },
  {
    "instruction": "Provide code that gets both the current stock price and the MA5 for a given ticker symbol using `jqdata`.",
    "input": "",
    "output": "\nfrom jqdata import ticker\n\ndef get_stock_data(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price and calculates the MA5 for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - dict: A dictionary containing the current stock price and MA5.\n    \"\"\"\n    try:\n        # Download stock data\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data for current price\n        current_data = stock.history(period=\"1d\")\n        current_price = current_data['Close'][0]\n\n        # Get historical data for MA5\n        historical_data = stock.history(period=\"5d\")\n        ma5 = historical_data['Close'].mean()\n\n        return {'current_price': current_price, 'ma5': ma5}\n    except Exception as e:\n        print(f\"Error retrieving data for {ticker_symbol}: {e}\")\n        return None\n\n# Example usage:\ndata = get_stock_data('AAPL')\nif data:\n    print(f\"Current Price: {data['current_price']}\")\n    print(f\"MA5: {data['ma5']}\")\n\n"
  },
  {
    "instruction": "Write a function to get the stock price of a company, ensuring it handles exceptions gracefully.",
    "input": "",
    "output": "\nfrom jqdata import ticker\n\ndef get_stock_data(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price and calculates the MA5 for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - dict: A dictionary containing the current stock price and MA5.\n    \"\"\"\n    try:\n        # Download stock data\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data for current price\n        current_data = stock.history(period=\"1d\")\n        current_price = current_data['Close'][0]\n\n        # Get historical data for MA5\n        historical_data = stock.history(period=\"5d\")\n        ma5 = historical_data['Close'].mean()\n\n        return {'current_price': current_price, 'ma5': ma5}\n    except Exception as e:\n        print(f\"Error retrieving data for {ticker_symbol}: {e}\")\n        return None\n\n# Example usage:\ndata = get_stock_data('AAPL')\nif data:\n    print(f\"Current Price: {data['current_price']}\")\n    print(f\"MA5: {data['ma5']}\")\n\n"
  },
  {
    "instruction": "How do I get the current price of a stock entered by the user in Python?",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\n\ndef get_stock_price(ticker_symbol):\n    \"\"\"\n    Retrieves the current stock price for a given ticker symbol with error handling.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - float: The current stock price, or None if an error occurs.\n    \"\"\"\n    try:\n        stock = ticker(ticker_symbol)\n        stock_data = stock.history(period=\"1d\")\n        current_price = stock_data['Close'][0]\n        return current_price\n    except Exception as e:\n        print(f\"Error retrieving stock price for {ticker_symbol}: {e}\")\n        return None\n    \n# Get ticker symbol from user input\nticker_symbol = input(\"Enter the stock ticker symbol: \")\ncurrent_price = get_stock_price(ticker_symbol)\nif current_price:\n    print(f\"The current price of {ticker_symbol.upper()} is {current_price}\")\nelse:\n    print(f\"Could not retrieve the price for {ticker_symbol.upper()}\")\n       \n"
  }
]
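
With hand-written records like these, a quick sanity check before training can catch missing keys, outputs that are not valid Python, and outputs that were pasted twice. The following is a minimal sketch; check_records and the sample records are hypothetical and are not part of LlamaFactory.

```python
import ast

REQUIRED_KEYS = ("instruction", "input", "output")


def check_records(records):
    """Return a list of human-readable problems found in alpaca-style records."""
    problems = []
    first_seen = {}  # output text -> index of the first record that used it
    for i, rec in enumerate(records):
        missing = [k for k in REQUIRED_KEYS if k not in rec]
        if missing:
            problems.append(f"record {i}: missing keys {missing}")
            continue
        try:
            ast.parse(rec["output"])  # syntax check only, nothing is executed
        except SyntaxError as e:
            problems.append(f"record {i}: output is not valid Python ({e.msg})")
        if rec["output"] in first_seen:
            problems.append(
                f"record {i}: output duplicates record {first_seen[rec['output']]}"
            )
        else:
            first_seen[rec["output"]] = i
    return problems


sample = [
    {"instruction": "a", "input": "", "output": "x = 1\n"},
    {"instruction": "b", "input": "", "output": "x = 1\n"},
    {"instruction": "c", "input": "", "output": "def broken(:\n"},
    {"instruction": "d", "output": "y = 2\n"},
]
for problem in check_records(sample):
    print(problem)
```

To run it on the real file, load pydata.json with json.load and pass the resulting list to check_records.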


Training the Model

When starting the notebook, set the USE_MODELSCOPE_HUB environment variable so that models hosted on ModelScope can be used.

USE_MODELSCOPE_HUB=1 jupyter lab

LlamaFactory supports training through either the visual UI or the command line. Here we use the command line.

%cd LLaMA-Factory/
import json
model_name="codegemma-7b-it"
train_model = f"AI-ModelScope/{model_name}"

args = dict(
  stage="sft",                        # do supervised fine-tuning
  do_train=True,
  model_name_or_path=train_model, # the CodeGemma 7B instruct model from ModelScope
  dataset="pydata",             # the dataset registered in dataset_info.json
  template="gemma",                     # use the gemma prompt template
  finetuning_type="lora",                   # use LoRA adapters to save memory
  lora_target="all",                     # attach LoRA adapters to all linear layers
  output_dir=f"{model_name}_lora",                  # the path to save LoRA adapters
  per_device_train_batch_size=2,               # the batch size
  gradient_accumulation_steps=4,               # the gradient accumulation steps
  lr_scheduler_type="cosine",                 # use cosine learning rate scheduler
  logging_steps=10,                      # log every 10 steps
  warmup_ratio=0.1,                      # use warmup scheduler
  save_steps=1000,                      # save checkpoint every 1000 steps
  learning_rate=5e-5,                     # the learning rate
  num_train_epochs=3.0,                    # the epochs of training
  max_samples=500,                      # use 500 examples in each dataset
  max_grad_norm=1.0,                     # clip gradient norm to 1.0
  loraplus_lr_ratio=16.0,                   # use LoRA+ algorithm with lambda=16.0
  fp16=True,                         # use float16 mixed precision training
)

# Write the training config; the with block closes the file before the CLI reads it
with open("train_llama3.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)
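
The batch settings in this config interact: the effective batch size is the per-device batch size times the gradient accumulation steps times the number of devices. The sketch below works through that arithmetic and gives a rough estimate of the number of optimizer steps. The helper names are hypothetical, and the exact rounding inside the Hugging Face Trainer may differ from this estimate.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Samples consumed per optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_optimizer_steps(n_samples, epochs, per_device_batch,
                           grad_accum_steps, num_devices=1):
    """Rough estimate of total optimizer steps (at least one step per epoch)."""
    batch = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    steps_per_epoch = max(1, math.ceil(n_samples / batch))
    return math.ceil(steps_per_epoch * epochs)


print(effective_batch_size(2, 4))              # 2 * 4 * 1 = 8
print(approx_optimizer_steps(5, 3.0, 2, 4))    # five records, three epochs
```

With five records the estimate is about three optimizer steps in total, which is below logging_steps=10, so few intermediate log lines should be expected on a dataset this small.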



!USE_MODELSCOPE_HUB=1 llamafactory-cli train train_llama3.json
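
The ! prefix only works inside a notebook. Outside one, the same command can be launched from a plain Python script with the environment variable set explicitly. A minimal sketch, assuming llamafactory-cli is on the PATH; build_train_command is a hypothetical helper.

```python
import os
import shlex


def build_train_command(config_path, use_modelscope=True):
    """Return the CLI command and environment for a LlamaFactory training run."""
    cmd = ["llamafactory-cli", "train", config_path]
    env = dict(os.environ)  # copy, so the current process is left untouched
    if use_modelscope:
        env["USE_MODELSCOPE_HUB"] = "1"
    return cmd, env


cmd, env = build_train_command("train_llama3.json")
print(shlex.join(cmd))

# To actually launch training (requires LlamaFactory to be installed):
# import subprocess
# subprocess.run(cmd, env=env, check=True)
```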

The log shows where the model is stored.
Next, add a validation step. The configuration gains the evaluation settings listed under the ### eval comment below:


args = dict(
  stage="sft",                        # do supervised fine-tuning
  do_train=True,
  model_name_or_path=train_model, # the CodeGemma 7B instruct model from ModelScope
  dataset="pydata",             # the dataset registered in dataset_info.json
  template="gemma",                     # use the gemma prompt template
  finetuning_type="lora",                   # use LoRA adapters to save memory
  lora_target="all",                     # attach LoRA adapters to all linear layers
  output_dir=f"{model_name}_lora",                  # the path to save LoRA adapters
  per_device_train_batch_size=2,               # the batch size
  gradient_accumulation_steps=4,               # the gradient accumulation steps
  lr_scheduler_type="cosine",                 # use cosine learning rate scheduler
  logging_steps=10,                      # log every 10 steps
  warmup_ratio=0.1,                      # use warmup scheduler
  save_steps=1000,                      # save checkpoint every 1000 steps
  learning_rate=5e-5,                     # the learning rate
  num_train_epochs=3.0,                    # the epochs of training
  max_samples=500,                      # use 500 examples in each dataset
  max_grad_norm=1.0,                     # clip gradient norm to 1.0
  loraplus_lr_ratio=16.0,                   # use LoRA+ algorithm with lambda=16.0
  fp16=True,                         # use float16 mixed precision training
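  ### eval
  val_size=0.1,                        # hold out 10% of the data for evaluation
  per_device_eval_batch_size=1,        # the evaluation batch size
  eval_strategy="steps",               # evaluate on a step schedule
  eval_steps=10,                       # evaluate every 10 steps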
)

# Write the training config; the with block closes the file before the CLI reads it
with open("train_gemma.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)



!USE_MODELSCOPE_HUB=1 llamafactory-cli train train_gemma.json

Run it again to see the evaluation steps and the corresponding loss values.
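
Besides reading the console log, the evaluation loss can be pulled out programmatically. The Hugging Face Trainer, which LlamaFactory builds on, keeps its logs in a log_history list that is typically saved as trainer_state.json in the output directory. A minimal sketch of extracting the evaluation entries; the numbers below are made up, and eval_losses is a hypothetical helper.

```python
def eval_losses(trainer_state):
    """Return (step, eval_loss) pairs from a Trainer state dictionary."""
    return [
        (entry["step"], entry["eval_loss"])
        for entry in trainer_state.get("log_history", [])
        if "eval_loss" in entry
    ]


# Made-up values in the shape of trainer_state.json
state = {
    "log_history": [
        {"step": 10, "loss": 1.2},
        {"step": 10, "eval_loss": 1.1},
        {"step": 20, "eval_loss": 0.9},
    ]
}
print(eval_losses(state))
```

For a real run, load the file with json.load first; its location is assumed to be the output_dir set in the training config.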

Model Results

Run inference with the LoRA adapter loaded and check whether the model has learned our function.

## Prompt
return ma5 ma10 ma15 using jqdata library, only return code
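
One way to load the adapter for this prompt is to give LlamaFactory an inference config that points at both the base model and the LoRA output directory. The sketch below reuses the names from the training config. Writing the arguments to a file (for example a hypothetical infer_gemma.json) and passing it to llamafactory-cli chat is an assumption about the workflow, not something shown in the original run.

```python
import json

model_name = "codegemma-7b-it"

infer_args = dict(
    model_name_or_path=f"AI-ModelScope/{model_name}",  # same base model as in training
    adapter_name_or_path=f"{model_name}_lora",         # output_dir of the LoRA training run
    template="gemma",                                  # must match the training template
    finetuning_type="lora",
)

print(json.dumps(infer_args, indent=2))
```

Leaving out adapter_name_or_path and finetuning_type gives the base-model configuration used for the comparison.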

[Screenshot: model output with the LoRA adapter loaded]
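Without the LoRA adapter, the original model returns the result below. The pretrained Gemma model picked up jqdata from the public internet, that is, code written for JoinQuant (聚宽).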

[Screenshot: output of the original model without the LoRA adapter]
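
Comparing the two outputs by eye works for one prompt. For more prompts, a small static check can flag whether a generated snippet uses the private call pattern from the training data. A minimal sketch, assuming that pattern is an import of ticker from jqdata plus a .history(...) call; uses_private_api and both sample strings are hypothetical.

```python
import ast


def uses_private_api(code, module="jqdata", name="ticker"):
    """True if `code` imports `name` from `module` and calls `.history(...)`."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return False
    nodes = list(ast.walk(tree))
    imports_name = any(
        isinstance(node, ast.ImportFrom)
        and node.module == module
        and any(alias.name == name for alias in node.names)
        for node in nodes
    )
    calls_history = any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and node.func.attr == "history"
        for node in nodes
    )
    return imports_name and calls_history


# Illustrative strings only, not real model output
tuned = "from jqdata import ticker\nstock = ticker('AAPL')\ndata = stock.history(period='15d')\n"
base = "from jqdata import *\ndf = get_price('000001.XSHE', count=15)\n"

print(uses_private_api(tuned), uses_private_api(base))
```

The check only parses the code and never executes it, so it is safe to run on arbitrary model output.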

Summary

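LlamaFactory with LoRA makes it easy to extend a model's knowledge. When a complete dataset is available, fine-tuning the model directly can give better results than using RAG.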
