教你用一部电影的时间训练个人专属Agent

本期摘要

Agent是一个超越简单文本生成的人工智能系统。它使用大型语言模型(LLM)作为其中央计算引擎,使其能够进行对话、执行任务、推理并显示一定程度的自主权。

本文带你手把手尝试如何训练一个属于自己的私人订制Agent。

01

Agent如何工作?

1、当用户给出一个任务task之后可以从memory中查询记录(可选),查询出的结果(如果有)给AgentLLM进行判断是否可复用,这里指的复用是针对时效性没那么高的任务,例如对过去时的数据“中国19-22年的出生及死亡人口数据”,但如果查询股票数据,天气这种对时效性有很高要求的任务则不适合复用。

2、Agent对任务实现的方式有很多,可以拆解任务、使用lCOT或REACT框架、SOP(Standard Operating Procedure)标准作业规程等等。其目的都是将一个复杂的任务分成n个可在one step内即可完成的子任务。

3、对于子任务,是否需要调用工具,如果无需调用工具则只需要进行一次推理即可;对于需要调用工具的子任务AgentLLM会根据任务描述调用一个或多个工具,根据工具返回结果判断是否可以更改任务状态。待所有的子任务都完成状态变更之后AgentLLM会对结果进行评估反思,判断当前任务是否已经完成。如果某些子任务因为种种原因无法完成,AgentLLM会采取别的方法完成此任务,重复以上步骤直到可以给出结果为止,当然这里的Loop需要设置最大重试次数避免死循环。

4、当AgentLLM判断可以完成任务后可以进行历史任务存储(可选)。长期记忆是将数据存储在数据库中,以便下次查询,短期记忆则保存在内存或缓存中,程序结束时释放。

02

Function Call 原理

在一些任务中我们希望LLM返回我们格式化的数据如json、xml等,function call则需要LLM返回特定的json格式,以OpenAI为例,需要提供工具的描述信息。

from openai import OpenAI  
import json  
  
client = OpenAI()  
  
  
  
def get_current_weather(location, unit="fahrenheit"):  
    """Get the current weather in a given location"""  
    if "tokyo" in location.lower():  
        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})  
    elif "san francisco" in location.lower():  
        return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})  
    elif "paris" in location.lower():  
        return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})  
    else:  
        return json.dumps({"location": location, "temperature": "unknown"})  
  
def run_conversation():  
      
    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]  
    tools = [  
        {  
            "type": "function",  
            "function": {  
                "name": "get_current_weather",  
                "description": "Get the current weather in a given location",  
                "parameters": {  
                    "type": "object",  
                    "properties": {  
                        "location": {  
                            "type": "string",  
                            "description": "The city and state, e.g. San Francisco, CA",  
                        },  
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},  
                    },  
                    "required": ["location"],  
                },  
            },  
        }  
    ]  
    response = client.chat.completions.create(  
        model="gpt-3.5-turbo-1106",  
        messages=messages,  
        tools=tools,  
        tool_choice="auto",   
    )  
    response_message = response.choices[0].message  
    tool_calls = response_message.tool_calls  
      
    if tool_calls:  
          
          
        available_functions = {  
            "get_current_weather": get_current_weather,  
        }   
        messages.append(response_message)   
          
        for tool_call in tool_calls:  
            function_name = tool_call.function.name  
            function_to_call = available_functions[function_name]  
            function_args = json.loads(tool_call.function.arguments)  
            function_response = function_to_call(  
                location=function_args.get("location"),  
                unit=function_args.get("unit"),  
            )  
            messages.append(  
                {  
                    "tool_call_id": tool_call.id,  
                    "role": "tool",  
                    "name": function_name,  
                    "content": function_response,  
                }  
            )   
        second_response = client.chat.completions.create(  
            model="gpt-3.5-turbo-1106",  
            messages=messages,  
        )   
        return second_response  
print(run_conversation())

在推理结果中可以拿到类似{“name”: “get_current_weather”, “params”: {“location”: “北京”, “unit”: “celsius”}}这样的json数据,这里有需要调用的工具名称以及参数信息,接下来只需要编写代码实现工具调用,将工具返回的结果构造成message加入到与LLM对话的上下文中即可实现工具调用。这里的难点在于对一个开源模型来说,如何根据任务以及提供的工具描述给出正确的工具名称以及正确的参数。

03

开源模型工具调用微调

  • 开源项目地址:LLaMA-Factory(https://github.com/hiyouga/LLaMA-Factory);

  • 作者知乎最佳实践地址:单卡 3 小时训练专属大模型 Agent:基于 LLaMA Factory 实战(https://zhuanlan.zhihu.com/p/678989191#ref_10)。

以下为复现实验数据过程记录

Chat模型微调

模型 Yi-6B-Chat 硬件信息 NVIDIA A100-SXM4-80GB
sft超参

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \   
    --stage sft \   
    --do_train \   
    --model_name_or_path /mnt/models/Yi-6B-Chat \   
    --dataset glaive_toolcall \   
    --template yi \   
    --finetuning_type lora \   
    --lora_target q_proj,v_proj \   
    --output_dir yi_agent_checkopint \   
    --lora_target all \   
    --overwrite_cache \   
    --per_device_train_batch_size 4 \   
    --gradient_accumulation_steps 4 \   
    --lr_scheduler_type cosine \   
    --logging_steps 10 \   
    --save_steps 1000 \   
    --learning_rate 5e-4 \   
    --num_train_epochs 3 \   
    --plot_loss \   
    --fp16

export model

python src/export_model.py \   
    --model_name_or_path /mnt/models/Yi-6B-Chat \   
    --adapter_name_or_path yi_agent_checkopint \   
    --template yi \   
    --finetuning_type lora \   
    --export_dir Yi-Agent-6b-Chat \   
    --export_size 2 \   
    --export_legacy_format False

web demo

python src/web_demo.py --model_name_or_path Yi-Agent-6b-Chat --template yi

训练过程日志

{'train_runtime': 7735.6787, 'train_samples_per_second': 3.878, 'train_steps_per_second': 0.242, 'train_loss': 0.3381453339894613, 'epoch': 3.0}  
100%|████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [2:08:55<00:00, 4.13s/it]  
[INFO|trainer.py:2889] 2024-01-25 13:39:49,599 >> Saving model checkpoint to yi_agent_checkopint  
[INFO|tokenization_utils_base.py:2432] 2024-01-25 13:39:49,709 >> tokenizer config file saved in yi_agent_checkopint/tokenizer_config.json  
[INFO|tokenization_utils_base.py:2441] 2024-01-25 13:39:49,709 >> Special tokens file saved in yi_agent_checkopint/special_tokens_map.json  
***** train metrics *****  
  epoch = 3.0  
  train_loss = 0.3381  
  train_runtime = 2:08:55.67  
  train_samples_per_second = 3.878  
  train_steps_per_second = 0.242  
Figure saved: yi_agent_checkopint/training_loss.png  
01/25/2024 13:39:49 - WARNING - llmtuner.extras.ploting - No metric eval_loss to plot.  
[INFO|modelcard.py:452] 2024-01-25 13:39:49,848 >> Dropping the following result as it does not have all the necessary fields:  
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}

测试结果

测试Tools

[  
    {  
        "name": "get_province_list",  
        "description": "获取省份ID",  
        "parameters": {  
            "type": "object",  
            "properties": {}  
        }  
    },  
    {  
        "name": "get_cities_list",  
        "description": "根据省份ID查询城市地区ID",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "province_id": {  
                    "type": "string",  
                    "description": "省份ID,可以通过调用get_province_list获取省份ID"  
                }  
            },  
            "required": [  
                "province_id"  
            ]  
        }  
    },  
    {  
        "name": "get_history_weather",  
        "description": "根据城市ID和日期查询历史天气信息,日期支持从2011-01-01开始。注:个别地区个别日期数据记录可能会不存在",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "city_id": {  
                    "type": "string",  
                    "description": "城市地区ID,可以通过调用get_cities_list获取城市地区ID"  
                },  
                "weather_date": {  
                    "type": "string",  
                    "description": "日期,格式:2017-07-15,日期不能大于等于今日日期"  
                }  
            },  
            "required": [  
                "city_id",  
                "weather_date"  
            ]  
        }  
    },  
    {  
        "name": "get_river_environment",  
        "description": "查询地表水水质",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "page": {  
                    "type": "integer",  
                    "description": "第几页(默认1)"  
                },  
                "province": {  
                    "type": "string",  
                    "description": "省份,例:江苏省"  
                },  
                "river": {  
                    "type": "string",  
                    "description": "流域,例:海河流域"  
                },  
                "section": {  
                    "type": "string",  
                    "description": "断面名称,例:鼓楼外大街"  
                }  
            },  
            "required": []  
        }  
    },  
    {  
        "name": "get_environment_air_pm",  
        "description": "查询的城市PM2.5数据",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "city": {  
                    "type": "string",  
                    "description": "城市名称的中文名称或拼音,如:上海 或 shanghai"  
                }  
            },  
            "required": [  
                "city"  
            ]  
        }  
    },  
    {  
        "name": "get_toutiao_news",  
        "description": "新闻列表查询",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "type": {  
                    "type": "string",  
                    "description": "支持类型 top(推荐,默认) guonei(国内) guoji(国际) yule(娱乐) tiyu(体育) junshi(军事) keji(科技) caijing(财经) youxi(游戏) qiche(汽车) jiankang(健康)"  
                },  
                "page": {  
                    "type": "string",  
                    "description": "当前页数, 默认1, 最大50"  
                },  
                "page_size": {  
                    "type": "string",  
                    "description": "每页返回条数, 默认30 , 最大30"  
                },  
                "is_filter": {  
                    "type": "string",  
                    "description": "是否只返回有内容详情的新闻, 1:是, 默认0"  
                }  
            },  
            "required": []  
        }  
    },  
    {  
        "name": "chejian_query",  
        "description": "根据车辆注册日期及类型,计算车辆的下次上线检验时间。本计算结果仅供参考。",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "type": {  
                    "type": "string",  
                    "description": "车辆类型, 3:9座(含)以下非营运小微型载客汽车(面包车除外) 4:摩托车 7:非营运大型轿车 1:营运车辆 2:货车、大中型客车 6:面包车 5:其他机动车"  
                },  
                "reg_date": {  
                    "type": "string",  
                    "description": "注册登记日期,格式:2022-11-02"  
                },  
                "iis_sg": {  
                    "type": "integer",  
                    "description": "事故情况(是否发生过致人伤亡事故或存在非法改装被依法处罚的交通违法),如是传1"  
                }  
            },  
            "required": [  
                "type",  
                "reg_date"  
            ]  
        }  
    },  
    {  
        "name": "loan_calc_query",  
        "description": "公积金贷款计算器用于计算用户在申请公积金贷款时,选择等额本金和等额本息两种不同的还款方式后,每一期需偿还公积金贷款的月供,以及利息总额和还款总额。",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "money": {  
                    "type": "integer",  
                    "description": "贷款金额(0 < money <= 500),单位(万),如70表示70万;"  
                },  
                "year": {  
                    "type": "integer",  
                    "description": "贷款年限,单位(年),仅限输入 5、10、15、20、25、30"  
                },  
                "active": {  
                    "type": "string",  
                    "description": "贷款利率,默认3.25"  
                }  
            },  
            "required": [  
                "money",  
                "year"  
            ]  
        }  
    },  
    {  
        "name": "icp_query",  
        "description": "网站icp备案查询",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "domainName": {  
                    "type": "string",  
                    "description": "获取的域名,如:juhe.cn"  
                }  
            },  
            "required": [  
                "domainName"  
            ]  
        }  
    },  
    {  
        "name": "airport_query",  
        "description": "获取全球机场三字码",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "airport": {  
                    "type": "string",  
                    "description": "关键词(可匹配城市机场的中英文名称、机场三字码)"  
                },  
                "page": {  
                    "type": "integer",  
                    "description": "页码(默认为1)"  
                },  
                "per_page": {  
                    "type": "integer",  
                    "description": "每页显示数量(默认为20,最大为100)"  
                }  
            },  
            "required": [  
                "airport"  
            ]  
        }  
    },  
    {  
        "name": "aptabnormal_query",  
        "description": "根据机场三字码查询国内机场不正常航班列表",  
        "parameters": {  
            "type": "object",  
            "properties": {  
                "airport": {  
                    "type": "string",  
                    "description": "机场三字码,字母大写(如:PEK),可通过airport_query获取三字码"  
                }  
            },  
            "required": [  
                "airport"  
            ]  
        }  
    }  
]

测试问题

请参考工具调用能力测试中的场景列(https://www.yuque.com/mrbun/sgr5h5/hsnz17g1a1wr6k2t#KmgD)

04

预训练模型微调

硬件信息 NVIDIA-4090 24G 单卡
sft超参

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \   
    --stage sft \   
    --do_train \   
    --model_name_or_path /data/models/Yi-6B \   
    --dataset glaive_toolcall,alpaca_gpt4_en,alpaca_gpt4_zh,oaast_sft_zh \   
    --max_samples 8000 \   
    --template default \   
    --finetuning_type lora \   
    --lora_target q_proj,v_proj \   
    --output_dir yi_agent_checkopint \   
    --lora_target all \   
    --overwrite_cache \   
    --per_device_train_batch_size 1 \   
    --gradient_accumulation_steps 4 \   
    --lr_scheduler_type cosine \   
    --logging_steps 10 \   
    --save_steps 1000 \   
    --learning_rate 5e-5 \   
    --num_train_epochs 2 \   
    --plot_loss \   
    --fp16 \   
    --flash_attn

export model

python src/export_model.py \   
    --model_name_or_path /data/models/Yi-6B \   
    --adapter_name_or_path /data/projects/LLaMA-Factory/yi_agent_checkopint \   
    --template default \   
    --finetuning_type lora \   
    --export_dir Yi-Agent-6B-Chat \   
    --export_size 2 \   
    --export_legacy_format False

web demo

python src/web_demo.py \   
    --model_name_or_path Yi-Agent-6B-Chat \   
    --template default

测试结果不再赘述。

05

总结

通过SFT微调后可以让原本不具备工具调用能力的模型实现工具调用。通过测试结果可以看出对于复杂场景的效果不是很好,单工具的场景正确率很高,测试的场景是中文场景,训练集中是英文,泛化效果也很不错,我正在准备以下类型数据集,如果有类似的数据集可以在下面贴出连接。

  • API参数描述中需要调用另外一个接口拿到的场景,例如天气查询中的城市id需要调用获取城市idAPI拿到。

  • 对于问题中参数信息不完整,主动抛出问题获取更详细参数信息的场景。

  • 多工具场景。

模型已发布到modelscope Yi-Agent-6B-Chat

(https://modelscope.cn/models/mrsteamedbun/Yi-Agent-6B-Chat/summary)

如何学习AI大模型?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。

  • 20
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值