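LlamaFactory is an open-source project for fine-tuning many kinds of language models. It ships with a Gradio web UI for visual training and inference. This article uses LlamaFactory to fine-tune CodeGemma so that the model can generate code that calls a private API.

Day-to-day development usually depends on an internal framework or internal class and component libraries. Open-source models either never saw this code during pre-training or saw only part of it. Suppose we have a stock-market codebase with a function that fetches a stock's historical prices. First, we have a model generate some code that gets a stock price, assuming that the `ticker` object in `jqdata` can retrieve a stock's historical data: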
```python
from jqdata import ticker  # private API assumed by this article

def get_stock_price(ticker_symbol):
    """
    This method retrieves the current stock price for a given ticker symbol.

    Parameters:
    - ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple, 'TSLA' for Tesla).

    Returns:
    - float: The current stock price.
    """
    try:
        # Download stock data for the given ticker
        # (the parameter is not named `ticker`, which would shadow the imported function)
        stock = ticker(ticker_symbol)
        # Get the latest market data
        stock_data = stock.history(period="1d")
        # Extract and return the closing price for the day
        current_price = stock_data['Close'][0]
        return current_price
    except Exception as e:
        print(f"Error retrieving stock price for {ticker_symbol}: {e}")
        return None
```
Our goal is for CodeGemma to learn this library's function for fetching historical stock prices, and to call that function during inference.
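The work has three parts: training, evaluation, and inference. Because the training set here is tiny (only five examples), evaluation uses a 0.1 validation split. The purpose is only to run the whole pipeline end to end; in production you need to prepare much more data for evaluation.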
Dataset Preparation
LlamaFactory datasets can come from three sources: Hugging Face, ModelScope, or local files. For convenience, this article uses a local file. Two file formats are supported, Alpaca and ShareGPT:
ShareGPT:
```json
{
  "conversations": [
    {
      "from": "human",
      "value": "Can you help me write a short story about a dragon and a knight?"
    },
    {
      "from": "gpt",
      "value": "Certainly! Here's a short story:\n\nOnce upon a time in a kingdom far away, a brave knight set out on a quest to find a fearsome dragon..."
    },
    {
      "from": "human",
      "value": "That's great! Can you make the dragon friendly instead of fearsome?"
    },
    {
      "from": "gpt",
      "value": "Of course! Let me revise the story:\n\nOnce upon a time in a kingdom far away, a brave knight set out on a quest and discovered a friendly dragon who..."
    }
  ]
}
```
Alpaca:
```json
{
  "instruction": "Explain the concept of gravitational waves.",
  "input": "",
  "output": "Gravitational waves are ripples in the fabric of spacetime caused by some of the most violent and energetic processes in the universe. They are produced when massive bodies accelerate, such as during the collision of black holes or neutron stars. Predicted by Albert Einstein's general theory of relativity, these waves propagate outward from their source at the speed of light, carrying information about their origins and about the nature of gravity itself."
}
```
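LlamaFactory converts ShareGPT- and Alpaca-format datasets into its own internal format, and the tokenizer's chat template finally turns them into training data. Create the training file and register it in `dataset_info.json`: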
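To make the conversion step concrete, here is a simplified sketch. It is not LlamaFactory's actual converter, and the function names are hypothetical. It maps both record styles onto one list of role/content turns: for Alpaca, the instruction (plus the optional input) becomes the user turn and the output becomes the assistant turn. For ShareGPT it assumes the default `human`/`gpt` tags. Rendering these turns with a chat template is the next step and is omitted here.

```python
def alpaca_to_messages(record: dict) -> list:
    """Alpaca record -> [user turn, assistant turn] (simplified illustration)."""
    prompt = record["instruction"]
    # Assumption: a non-empty "input" is appended to the instruction on a new line
    if record.get("input"):
        prompt = prompt + "\n" + record["input"]
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": record["output"]},
    ]


def sharegpt_to_messages(record: dict) -> list:
    """ShareGPT record with default tags ("from": human/gpt, text in "value")."""
    role_map = {"human": "user", "gpt": "assistant"}
    return [
        {"role": role_map[turn["from"]], "content": turn["value"]}
        for turn in record["conversations"]
    ]


if __name__ == "__main__":
    sample = {
        "instruction": "Explain the concept of gravitational waves.",
        "input": "",
        "output": "Gravitational waves are ripples in the fabric of spacetime...",
    }
    for msg in alpaca_to_messages(sample):
        print(msg["role"], "->", msg["content"][:40])
```

Both formats end up as the same alternating user/assistant structure, which is why a single chat template can serve either file format.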
- Add the following entry to `data/dataset_info.json`:
```json
"pydata": {
  "file_name": "pydata.json"
},
```
- Prepare the data file `data/pydata.json`:
```json
[
  {
    "instruction": "Write a Python function to get the current stock price for a given ticker symbol using the `jqdata` library.",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\ndef get_stock_price(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple, 'TSLA' for Tesla).\n\n    Returns:\n    - float: The current stock price.\n    \"\"\"\n    try:\n        # Download stock data for the given ticker\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data\n        stock_data = stock.history(period=\"1d\")\n\n        # Extract and return the closing price for the day\n        current_price = stock_data['Close'][0]\n        return current_price\n    except Exception as e:\n        print(f\"Error retrieving stock price for {ticker_symbol}: {e}\")\n        return None\n\n"
  },
  {
    "instruction": "How can I calculate the 5-day moving average (MA5) of a stock using Python and the `jqdata` library?",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\ndef get_ma5(ticker_symbol):\n    \"\"\"\n    This method calculates the 5-day moving average (MA5) for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - float: The 5-day moving average of the closing prices.\n    \"\"\"\n    try:\n        # Download stock data for the past 5 days\n        stock = ticker(ticker_symbol)\n        stock_data = stock.history(period=\"5d\")\n\n        # Calculate the MA5\n        ma5 = stock_data['Close'].mean()\n        return ma5\n    except Exception as e:\n        print(f\"Error calculating MA5 for {ticker_symbol}: {e}\")\n        return None\n"
  },
  {
    "instruction": "Provide code that gets both the current stock price and the MA5 for a given ticker symbol using `jqdata`.",
    "input": "",
    "output": "\nfrom jqdata import ticker\n\ndef get_stock_data(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price and calculates the MA5 for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - dict: A dictionary containing the current stock price and MA5.\n    \"\"\"\n    try:\n        # Download stock data\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data for current price\n        current_data = stock.history(period=\"1d\")\n        current_price = current_data['Close'][0]\n\n        # Get historical data for MA5\n        historical_data = stock.history(period=\"5d\")\n        ma5 = historical_data['Close'].mean()\n\n        return {'current_price': current_price, 'ma5': ma5}\n    except Exception as e:\n        print(f\"Error retrieving data for {ticker_symbol}: {e}\")\n        return None\n\n# Example usage:\ndata = get_stock_data('AAPL')\nif data:\n    print(f\"Current Price: {data['current_price']}\")\n    print(f\"MA5: {data['ma5']}\")\n\n"
  },
  {
    "instruction": "Write a function to get the stock price of a company, ensuring it handles exceptions gracefully.",
    "input": "",
    "output": "\nfrom jqdata import ticker\n\ndef get_stock_data(ticker_symbol):\n    \"\"\"\n    This method retrieves the current stock price and calculates the MA5 for a given ticker symbol.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - dict: A dictionary containing the current stock price and MA5.\n    \"\"\"\n    try:\n        # Download stock data\n        stock = ticker(ticker_symbol)\n\n        # Get the latest market data for current price\n        current_data = stock.history(period=\"1d\")\n        current_price = current_data['Close'][0]\n\n        # Get historical data for MA5\n        historical_data = stock.history(period=\"5d\")\n        ma5 = historical_data['Close'].mean()\n\n        return {'current_price': current_price, 'ma5': ma5}\n    except Exception as e:\n        print(f\"Error retrieving data for {ticker_symbol}: {e}\")\n        return None\n\n# Example usage:\ndata = get_stock_data('AAPL')\nif data:\n    print(f\"Current Price: {data['current_price']}\")\n    print(f\"MA5: {data['ma5']}\")\n\n"
  },
  {
    "instruction": "How do I get the current price of a stock entered by the user in Python?",
    "input": "",
    "output": "\n\nfrom jqdata import ticker\n\n\ndef get_stock_price(ticker_symbol):\n    \"\"\"\n    Retrieves the current stock price for a given ticker symbol with error handling.\n\n    Parameters:\n    - ticker_symbol (str): The stock ticker symbol.\n\n    Returns:\n    - float: The current stock price, or None if an error occurs.\n    \"\"\"\n    try:\n        stock = ticker(ticker_symbol)\n        stock_data = stock.history(period=\"1d\")\n        current_price = stock_data['Close'][0]\n        return current_price\n    except Exception as e:\n        print(f\"Error retrieving stock price for {ticker_symbol}: {e}\")\n        return None\n\n# Get ticker symbol from user input\nticker_symbol = input(\"Enter the stock ticker symbol: \")\ncurrent_price = get_stock_price(ticker_symbol)\nif current_price:\n    print(f\"The current price of {ticker_symbol.upper()} is {current_price}\")\nelse:\n    print(f\"Could not retrieve the price for {ticker_symbol.upper()}\")\n"
  }
]
```
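The two preparation steps above can also be scripted so that a malformed record is caught before training starts. This is a minimal sketch with hypothetical helper names: it checks that every record has the three Alpaca keys, and it adds the `pydata` entry to `dataset_info.json` while keeping the existing entries. It assumes the script is run from the `LLaMA-Factory/` directory, where the dataset files live under `data/`.

```python
import json
from pathlib import Path

REQUIRED_KEYS = ("instruction", "input", "output")


def validate_alpaca_file(path: Path) -> int:
    """Check that every record has the Alpaca keys; return the number of records."""
    records = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(records, list):
        raise ValueError(f"{path} must contain a JSON list of records")
    for i, rec in enumerate(records):
        missing = [key for key in REQUIRED_KEYS if key not in rec]
        if missing:
            raise ValueError(f"record {i} in {path} is missing {missing}")
    return len(records)


def register_dataset(info_path: Path, name: str, file_name: str) -> None:
    """Add (or overwrite) one dataset entry, keeping every other entry intact."""
    info = json.loads(info_path.read_text(encoding="utf-8")) if info_path.exists() else {}
    info[name] = {"file_name": file_name}
    info_path.write_text(json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8")


# Example, run from LLaMA-Factory/:
# print(validate_alpaca_file(Path("data/pydata.json")))
# register_dataset(Path("data/dataset_info.json"), "pydata", "pydata.json")
```

Reading and rewriting the whole file, instead of appending text, keeps `dataset_info.json` valid JSON no matter where the new entry ends up.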
Training the Model
When launching the notebook, set the `USE_MODELSCOPE_HUB` environment variable so that models are downloaded from ModelScope:
```shell
USE_MODELSCOPE_HUB=1 jupyter lab
```
LlamaFactory supports training both through its visual web UI and from the command line. Here we train from the command line.
```python
%cd LLaMA-Factory/
import json

model_name = "codegemma-7b-it"
train_model = f"AI-ModelScope/{model_name}"  # CodeGemma 7B instruct model on ModelScope

args = dict(
    stage="sft",                      # do supervised fine-tuning
    do_train=True,
    model_name_or_path=train_model,   # base model to fine-tune
    dataset="pydata",                 # the dataset registered in dataset_info.json
    template="gemma",                 # use the Gemma prompt template
    finetuning_type="lora",           # use LoRA adapters to save memory
    lora_target="all",                # attach LoRA adapters to all linear layers
    output_dir=f"{model_name}_lora",  # the path to save LoRA adapters
    per_device_train_batch_size=2,    # the batch size
    gradient_accumulation_steps=4,    # the gradient accumulation steps
    lr_scheduler_type="cosine",       # use cosine learning rate scheduler
    logging_steps=10,                 # log every 10 steps
    warmup_ratio=0.1,                 # use warmup scheduler
    save_steps=1000,                  # save checkpoint every 1000 steps
    learning_rate=5e-5,               # the learning rate
    num_train_epochs=3.0,             # the epochs of training
    max_samples=500,                  # use at most 500 examples from each dataset
    max_grad_norm=1.0,                # clip gradient norm to 1.0
    loraplus_lr_ratio=16.0,           # use LoRA+ algorithm with lambda=16.0
    fp16=True,                        # use float16 mixed precision training
)
with open("train_gemma.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)
!USE_MODELSCOPE_HUB=1 llamafactory-cli train train_gemma.json
```
The training log shows where the model is saved.
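Besides reading the log, you can check the `output_dir` directly (here `codegemma-7b-it_lora`). The helper below is a hypothetical sketch. It assumes that PEFT's `save_pretrained()` writes `adapter_config.json` next to the adapter weights, named `adapter_model.safetensors`, or `adapter_model.bin` in older setups.

```python
from pathlib import Path
from typing import Optional

# Assumption: PEFT's usual file names for a saved LoRA adapter
ADAPTER_CONFIG = "adapter_config.json"
WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def find_lora_adapter(output_dir: str) -> Optional[Path]:
    """Return the directory if it looks like a saved LoRA adapter, else None."""
    out = Path(output_dir)
    if not (out / ADAPTER_CONFIG).is_file():
        return None
    if any((out / name).is_file() for name in WEIGHT_FILES):
        return out
    return None


# Example, run from LLaMA-Factory/ after training:
# print(find_lora_adapter("codegemma-7b-it_lora"))
```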
Next, add an evaluation step. The configuration below extends the previous one with the eval settings under the `### eval` comment:
```python
args = dict(
    stage="sft",                      # do supervised fine-tuning
    do_train=True,
    model_name_or_path=train_model,   # base model to fine-tune
    dataset="pydata",                 # the dataset registered in dataset_info.json
    template="gemma",                 # use the Gemma prompt template
    finetuning_type="lora",           # use LoRA adapters to save memory
    lora_target="all",                # attach LoRA adapters to all linear layers
    output_dir=f"{model_name}_lora",  # the path to save LoRA adapters
    per_device_train_batch_size=2,    # the batch size
    gradient_accumulation_steps=4,    # the gradient accumulation steps
    lr_scheduler_type="cosine",       # use cosine learning rate scheduler
    logging_steps=10,                 # log every 10 steps
    warmup_ratio=0.1,                 # use warmup scheduler
    save_steps=1000,                  # save checkpoint every 1000 steps
    learning_rate=5e-5,               # the learning rate
    num_train_epochs=3.0,             # the epochs of training
    max_samples=500,                  # use at most 500 examples from each dataset
    max_grad_norm=1.0,                # clip gradient norm to 1.0
    loraplus_lr_ratio=16.0,           # use LoRA+ algorithm with lambda=16.0
    fp16=True,                        # use float16 mixed precision training
    ### eval
    val_size=0.1,                     # hold out 10% of the data for validation
    per_device_eval_batch_size=1,     # the evaluation batch size
    eval_strategy="steps",            # evaluate every eval_steps steps
    eval_steps=10,                    # evaluate every 10 steps
)
with open("train_gemma.json", "w", encoding="utf-8") as f:
    json.dump(args, f, indent=2)
!USE_MODELSCOPE_HUB=1 llamafactory-cli train train_gemma.json
```
Run training again and check the evaluation step and its loss values.
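With only five examples, the batch settings above leave very few optimizer steps, which affects which logging and evaluation intervals are ever reached. The rough calculation below makes this concrete. The helper `estimate_optimizer_steps` is hypothetical. It assumes a single GPU, that `val_size` holds out `ceil(n * val_size)` examples, and that each epoch runs at least one optimizer step. This approximates how the Hugging Face `Trainer` counts steps; exact rounding can differ slightly between `transformers` versions.

```python
import math


def estimate_optimizer_steps(num_examples, val_size, batch_size, grad_accum, epochs, num_devices=1):
    """Return (train examples, validation examples, total optimizer steps).

    Rough approximation under the stated assumptions; not LlamaFactory's own code.
    """
    # Assumption: a fractional val_size holds out ceil(num_examples * val_size) examples
    n_val = math.ceil(num_examples * val_size) if val_size else 0
    n_train = num_examples - n_val
    # Micro-batches per device per epoch; the last partial batch is kept
    batches_per_epoch = math.ceil(n_train / (batch_size * num_devices))
    # One optimizer step per grad_accum micro-batches, and at least one per epoch
    steps_per_epoch = max(math.ceil(batches_per_epoch / grad_accum), 1)
    return n_train, n_val, math.ceil(steps_per_epoch * epochs)


# Settings from this article: 5 records, val_size=0.1, batch 2, accumulation 4, 3 epochs
print(estimate_optimizer_steps(5, 0.1, 2, 4, 3.0))  # → (4, 1, 3)
```

By this estimate, training uses 4 examples, validation uses 1, and there are about 3 optimizer steps in total. Intervals such as `logging_steps=10` and `eval_steps=10` are therefore not reached during training. With a dataset this small, smaller intervals (for example 1) are one way to see intermediate loss values.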
Model Results
Run inference with the LoRA adapter loaded and check whether the model has learned our function.
Prompt:
```text
return ma5 ma10 ma15 using jqdata library, only return code
```
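One way to run this comparison is sketched below. It assumes that `llamafactory-cli chat` accepts a JSON argument file the same way `train` does in this article. The file names `infer_gemma_lora.json` and `infer_gemma_base.json` are hypothetical. The adapter path is the `output_dir` from training. The same base model and template are used both with and without the adapter, so any difference in output comes from the LoRA weights.

```python
import json

# Assumption: run inside LLaMA-Factory/, after training wrote the adapter to f"{model_name}_lora"
model_name = "codegemma-7b-it"

base_args = dict(
    model_name_or_path=f"AI-ModelScope/{model_name}",  # same base model as training
    template="gemma",                                  # same prompt template as training
)
lora_args = dict(
    base_args,
    adapter_name_or_path=f"{model_name}_lora",  # LoRA adapter from the training step
    finetuning_type="lora",
)

with open("infer_gemma_lora.json", "w", encoding="utf-8") as f:
    json.dump(lora_args, f, indent=2)
with open("infer_gemma_base.json", "w", encoding="utf-8") as f:
    json.dump(base_args, f, indent=2)

# Then, in notebook cells, start an interactive chat and paste the prompt above:
# !USE_MODELSCOPE_HUB=1 llamafactory-cli chat infer_gemma_lora.json   # with the LoRA adapter
# !USE_MODELSCOPE_HUB=1 llamafactory-cli chat infer_gemma_base.json   # base model only
```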
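Without the LoRA adapter, the original model's answer shows that the pre-trained Gemma model had already learned about jqdata, that is, JoinQuant's library, from the public internet.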
Summary
LlamaFactory with LoRA makes it easy to extend a model's knowledge. When a complete dataset is available, training the model directly can give better results than RAG.