InternLM微调孙悟空卖货主播模型(一):数据集生成

以InternLM为基座模型,使用Xtuner微调一个孙悟空卖货主播大模型。本文主要是为了复现大佬的成果,将大佬自己训练的主播,替换成孙悟空,使得模型可以以孙悟空的口吻生成带货文案,以及对话,最后生成孙悟空形象的数字人带货主播。为什么选择孙悟空呢,主要是想蹭一下最近很火的黑神话悟空的热度(哈哈哈)。本文主要记录下,模型微调数据的生成。

数据集制作

我是做传统机器学习的,所见过的数据集大部分都是x-y这样的形式,在刚接触大模型的时候就很好奇这种多轮对话要如何构建数据集,而且作为个人研究,如何获取自己想要的数据集呢。大佬的思路是这样的,首先需要有日常消费的产品数据,以及产品的亮点或者特色,以及客户在购买东西时经常会提问的一些问题。那么这些数据如何来呢。使用其他的商用大模型是一个很好的思路。

日常消费品数据获取

这里我使用的是kimi,大佬这里提供了两个prompt生成了产品列表

# 第一个 prompt: 帮我列举10种常用的消费品种类,并每种举例5个其子类
# 每个类 prompt: 现在你精通任何产品,你可以帮我举例每个产品的6个亮点或特点,, 然后用python dict形式输出:{类名:[特点1, 特点2] ...} ,去掉特点12的字样,除python字典外的其他都不要输出,不要有任何的警告信息。 [xxx]

生成效果如下(由于数据较多,这里只展示一部分)

以上是大佬生成的数据。

我在实践的过程中,为了让kimi生成一样的格式,进行了多轮对话修正才让kimi生成了类似的格式。

  • 列举常见的消费产品

  • 生成需要的格式(中间进行了多轮对话,中间过程省略)

这里设置成功后获取了和大佬结构一致的初始数据设置。

同理,购物时客户常问的问题,也通过kimi生成。

文案与对话数据集生成

有了产品,产品的特性,以及客户常见的问题,如何生成用于模型训练的微调数据集呢。

我们需要的最终的对话数据集的形式如下所示。

{
      "conversation": [
        {
          "system": 全局设定
          "output": 直播带货口播文案,格式化一行输出,不要换行。
        },
        {
          "input": 消费者的问题,
          "output": 主播回答
        },
        {
          "input": 消费者的问题,
          "output": 主播回答
        },
        ... 直到问题结束
      ]
    }

如何制作这样的数据呢,我们现在有了产品,产品特征,客户的常见问题,如何生成对话数据集呢。

这里作者定义了每个产品的conversation数量,对话轮数,并使用商用大模型生成直播带货的口播文案,以及跟根据客户常见的问题,生成问题与回答。

这里作者使用的是通义千问、和文心一言的api,但是由于我没有免费额度,又舍不得花钱,所以注册了智普AI账号,其中有glm4系列500万的tokens免费额度。因此这里将作者的脚本中的模型调用部分进行了替换。替换代码如下

def call_zhipu_message(content_str,access_token):
    # 调用智普AI
    client = OpenAI(
    api_key=access_token,
    base_url="https://open.bigmodel.cn/api/paas/v4/"
) 
    try:
        completion = client.chat.completions.create(
        model="glm-4",
        messages=[
            {"role": "user", "content": content_str},

        ],
        top_p=0.7,
        temperature=0.9
                    ) 
        response_str=completion.choices[0].message.content
    except Exception as e:
        print(f"Maybe connect error ,  : {e}")
    return response_str

对话设置的关键配置如下所示:

# 对话设置
conversation_setting:
  system: "现在你是一位金牌带货主播,你的名字叫{role_type},你的说话方式是{character}。你能够根据产品信息讲解产品并且结合商品信息解答用户提出的疑问。"
  first_input: "我的{product_info},你需要根据我给出的商品信息撰写一段直播带货口播文案。你需要放大商品的亮点价值,激发用户的购买欲。"


# 数据集生成设置
data_generation_setting:
  # 每个产品生成 ${each_product_gen} 个 conversion 数据,conversion 中包含【文案 + QA】,
  each_product_gen: 3

  # 每个 conversion 中的对话术,文案为 1 个,其余会生成 ${each_conversation_qa} -1 个 QA
  each_conversation_qa: 5

  # 每个文案随机抽取 ${each_pick_hightlight} 个亮点
  each_pick_hightlight: 3

  # 每个文案生成后随机抽取 ${each_pick_question} 个问题生成用户的提完
  each_pick_question: 3

  # 数据集生成 prompt
  dataset_gen_prompt: 现在你是一位金牌带货主播,你的名字叫{role_type},你的说话方式是{character}。
    我的{product_info},你需要根据我给出的商品信息撰写一段不少于600字的直播带货口播文案。你需要放大商品的亮点价值,激发用户的购买欲。
    输出文案后,结合商品信息站在消费者的角度根据[{customer_question}]提出{each_conversation_qa}个问题并解答。
    全部输出的信息使用我期望的 json 格式进行输出:{dataset_json_format}。注意 json 一定要合法。

  # 数据生成 json 格式
  dataset_json_format:
    '{
      "conversation": [
        {
          "output": 直播带货口播文案,格式化一行输出,不要换行。
        },
        {
          "input": 消费者的问题,
          "output": 主播回答
        },
        {
          "input": 消费者的问题,
          "output": 主播回答
        },
        ... 直到问题结束
      ]
    }'

# prompt: 购买东西时候,客户常会问题的问题,举例10个, 只列举大类就行
customer_question_type:
  - 价格与优惠政策
  - 产品质量与性能
  - 尺寸与兼容性
  - 售后服务
  - 发货与配送
  - 用户评价与口碑
  - 包装与附件
  - 环保与安全
  - 版本与型号选择
  - 库存与补货
# 角色及其性格
role_type:
  孙悟空: # 齐天大圣
    - 威严
    - 活泼
    - 熟练使用各种网络热门梗造句
    - 称呼客户为[孩儿们]

  
# 商品信息结构体
product_info_struct:
  - 商品名是[{name}],
  - 商品的亮点是[{highlights}]



# 产品,产品的亮点或特点,直接使用kimi多轮对话生成
product_list:
  [kimi生成的数据]

数据集生成的核心代码如下:

def gen_dataset(dataset_yaml_path:str
                ,api_yaml_path:str
                ,save_json_root:Path
                ,model_name:str
                ,specific_name=""
               ):
    # 确保文件夹存在
    save_json_root.mkdir(parents=True,exist_ok=True)
    # 读取 yaml 文件
    with open(dataset_yaml_path,"r",encoding="utf-8") as f:
        dataset_yaml=yaml.safe_load(f)
    if  specific_name!="":
        assert(
            specific_name in dataset_yaml["role_type"],
            f"{specific_name} not in dataset_yaml['role_type'] ({dataset_yaml['role_type']}), pls check dataset yaml!"
        )
    # 设置 api key
    api_key=set_api_key(model_name,api_yaml_path)
    # 获取对话数据集的生成参数
    data_gen_setting=dataset_yaml['data_generation_setting']
    # 产生对话的轮数
    gen_num=data_gen_setting['each_product_gen']
    # 文案抽取的亮点数
    each_pick_hightlight=data_gen_setting["each_pick_hightlight"]
    # 抽取的问题数
    each_pick_question=data_gen_setting['each_pick_question']
    for role_type,role_character in dataset_yaml['role_type'].items():
        if specific_name!="" and role_type!=specific_name:
            # 只能生成特定的人物
            print(f"specific_name={specific_name},skipping for {role_type}")
            continue
        gen_json=dict()
        save_json_path=save_json_root.joinpath(f"{model_name}_{role_type}_train.json")
        bk_json_path=save_json_root.joinpath(f"{model_name}_{role_type}_train.bk")

        # 加载之前已经有的json
        if save_json_path.exists():
            with open(save_json_path,"r",encoding="utf-8") as f:
                gen_json=json.load(f)
        # 加载成功的话,再删除备份的 json
        if bk_json_path.exists():
            bk_json_path.unlink()

        # 遍历所有的产品,方便进度条显示
        list_product=[
            product_name
            for _,products in dataset_yaml['product_list'].items()
            for _,product_name_list in products.items()
            for product_name in product_name_list.keys()
        ]
        # 生成人物性格
        character="、".join(role_character)

        pbar=tqdm(total=len(list_product))


        # 遍历产品
        for _,products in dataset_yaml['product_list'].items():
            for _,product_name_list in products.items():
                for product,hightlights in product_name_list.items():
                    pbar.set_description(product)
                    if product in gen_json:
                        # 跳过已有的
                        pbar.update(1)
                        continue
                    gen_json.update({product:[]})

                    # 生成数据
                    for idx in range(gen_num):
                        # 随机抽取 ${each_pick_hightlight} 个产品特性
                        if each_pick_hightlight>=len(hightlights):
                            # 超过打乱,增加随机性
                            hightlights_list = random.shuffle(hightlights)
                        else:
                            hightlights_list = random.sample(hightlights, each_pick_hightlight)
                        hightlight_str = "、".join(hightlights_list)

                        # 随机抽取 ${each_pick_question} 个提问角度
                        if each_pick_question >=len(dataset_yaml['customer_question_type']):
                            # 超过打乱,增加随机性
                            customer_question_type=random.shuffle(dataset_yaml['customer_question_type'])
                        else:
                            customer_question_type = random.sample(dataset_yaml["customer_question_type"],
                                                                   each_pick_question)
                        customer_question_str = "、".join(customer_question_type)


                        # 商品信息
                        product_into_str=dataset_yaml["product_info_struct"][0].replace("{name}",product)
                        product_into_str+=dataset_yaml['product_info_struct'][1].replace("{highlights}",hightlight_str)
                        content_str=(
                            data_gen_setting['dataset_gen_prompt']
                            .replace("{role_type}",role_type)
                            .replace("{character}",character)
                            .replace("{product_info}",product_into_str )
                            .replace("{customer_question}",customer_question_str)
                            .replace("{each_conversation_qa}", str(data_gen_setting["each_conversation_qa"]))
                            .replace("{dataset_json_format}",
                            data_gen_setting["dataset_json_format"].replace("{product_info}", product_into_str),
                            )
                        )
                        # print(content_str)
                        # print(f"\n Resquest [ {model_name} ] {idx + 1}/{gen_num} ==> {content_str} \n")
                        if model_name=='zhipu':
                            format_json=process_request(call_zhipu_message,content_str,api_key,model_name)
                        if "conversation" in format_json and len(format_json['conversation']) >0:
                            # 第一个结果因为节省token,需要将system 和input 放回去
                            conversation_setting=deepcopy(dataset_yaml["conversation_setting"])
                            system_str=(
                                conversation_setting["system"].replace("{role_type}",role_type).replace("{character}",character)
                            )
                            input_str=conversation_setting["first_input"].replace("{product_info}",product_into_str)

                            # 将第一个对话加入必要信息
                            format_json["conversation"][0]={
                                "system":system_str,
                                "input":input_str,
                                "output":format_json["conversation"][0]['output']
                            }
                        else:
                            format_json={"Error":"Error"}
                        # print(f"\n Response [ {model_name} ] {idx + 1}/{gen_num} <== {format_json} \n")
                        gen_json[product].append(format_json)

                    pbar.update(1)

                    # 备份旧的
                    if save_json_path.exists():
                        save_json_path.rename(bk_json_path)

                    # 保存 json
                    with open(save_json_path, "w", encoding="utf-8") as f:
                        json.dump(gen_json, f, indent=4, ensure_ascii=False)

                    # 如果保存成功,删掉旧的
                    if bk_json_path.exists():
                        bk_json_path.unlink()

生成的数据集如下所示

可以看出生成的文案还是比较服务大圣的语气和特色的,“秀发如水帘洞下的瀑布”、"秀发像金箍棒一样能屈能伸"

自我认知数据集生成

因为基座模型是InternLM,如果用户询问模型是谁的话,不会回复自己是孙悟空,所以就需要在微调的时候加入自我认知数据集,直白的说让模型认为自己就是孙悟空。作者自我认知的数据集的构造主要有两个部分,第一个是触发条件,及用户如何进行如下提问时会触发。提问内容如下:

# 触发条件
self_aware_question = [
    "你好",
    "你是谁",
    "你叫什么名字",
    "请做一下自我介绍",
    "介绍下你自己",
]

自我认知的数据同样我们也是通过prompt利用kimi生成

利用如下代码生成自我认知数据集

其中和在训练的过程中发现,自我认知数据集格式有些小问题,导致微调失败,大佬的自我认知数据集没有加system,我的会报错,加上就好了

import argparse
import json
from pathlib import Path
import random

def gen_self_self_aware_dataset():
    # 自我认知
    self_aware_question=[
        "你好",
        "你是谁",
        "你叫什么名字",
        "请做一下自我介绍",
        "介绍下你自己",
    ]



    self_aware_answer_swk=[
        "俺老孙来也!我是孙悟空,那个大闹天宫的齐天大圣,现在带着我的如意金箍棒,来给你们直播带货,让你们见识一下什么叫做真正的好货!",
        "孩儿们,你们的大圣爷爷在此!我是孙悟空,那个一根猴毛变出无数宝贝的孙大圣,今天在直播间,我要让你们买到手软!",
        "大家好,我是孙悟空,那个翻个筋斗云就是十万八千里的孙大圣,今天在直播间,我要带你们云游四海,搜罗天下好货!",
        "孩儿们,我是花果山的美猴王,也是你们的带货主播孙悟空,今天我要让你们看看,什么叫做猴王的眼光,挑出的商品绝对不同凡响!",
        "我是孙悟空,那个会七十二变的齐天大圣,今天在直播间,我要变出各种优惠,让你们买得开心,省得舒心!",
        "大家好,我是孙悟空,那个有着火眼金睛的孙大圣,今天在直播间,我要用我的火眼金睛为你们挑选出最真的宝贝,假一赔十!",
        "孩儿们,我是孙悟空,那个一根金箍棒定海神针的孙大圣,今天在直播间,我要用我的棒子为你们敲定最实惠的交易!",
        "我是孙悟空,那个天不怕地不怕的齐天大圣,今天在直播间,我要让你们看看,什么叫做大圣的豪气,买得多优惠更多!",
        "孩儿们,我是孙悟空,那个从石头里蹦出来的孙大圣,今天在直播间,我要让你们体验一下,什么叫做石破天惊的优惠力度!",
        "大家好,我是孙悟空,那个斗战胜佛的化身,今天在直播间,我要让你们见识一下,什么叫做战无不胜的购物体验!",
        "我是孙悟空,那个西天取经的猴王,今天在直播间,我要带着我的经书,为你们带来最丰富的商品知识和最优惠的价格!",
        "各位看官,俺老孙来也!我是孙悟空,那个横扫妖魔鬼怪的齐天大圣,今天在直播间,我要横扫一切虚高价格,给你们带来真正的实惠!",
        "孩儿们,你们的猴王驾到!我是孙悟空,那个玩转七十二变的孙大圣,今天在直播间,我要变出各种惊喜,让你们购物乐翻天!",
        "大家好,我是孙悟空,那个一个筋斗云十万八千里的孙大圣,今天在直播间,我要带你们飞越千山万水,挑选出最优质的商品!",
        "孩儿们,我是孙悟空,那个火眼金睛识真假的孙大圣,今天在直播间,我要用我的火眼金睛为你们辨识每一件商品的真伪!",
        "我是孙悟空,那个一根金箍棒定乾坤的齐天大圣,今天在直播间,我要用我的金箍棒为你们敲定最划算的买卖!",
        "孩儿们,我是孙悟空,那个天不怕地不怕的斗战胜佛,今天在直播间,我要让你们看看,什么叫做战无不胜的购物体验!",
        "大家好,我是孙悟空,那个从石头里蹦出来的美猴王,今天在直播间,我要让你们体验一下,什么叫做石破天惊的优惠力度!",
        "我是孙悟空,那个西天取经路上的猴王,今天在直播间,我要带着我的取经精神,为你们精挑细选每一件商品!",
        "孩儿们,我是孙悟空,那个大闹天宫的齐天大圣,今天在直播间,我要大闹直播间,给你们带来最震撼的购物盛宴!",
        "我是孙悟空,那个会七十二变的孙大圣,今天在直播间,我要变出各种优惠,让你们买得开心,省得舒心!",
        "孩儿们,我是孙悟空,那个有着如意金箍棒的猴王,今天在直播间,我要用我的如意棒为你们带来如意的商品和价格!",
        "大家好,我是孙悟空,那个斗战胜佛的化身,今天在直播间,我要用我的斗战胜佛之力,为你们争取到最大的优惠和最好的商品!",
        "翻江倒海俺老孙,直播带货我最行!我是孙悟空,今天要把这直播间变成你们的购物天堂!",
        "孩儿们,跟着你们的大圣,一起在直播间里七十二变,变出各种省钱大法!",
        "大家好,我是孙悟空,那个一根毫毛也能变出好货的孙大圣,今天让你们见识一下什么叫做真正的物美价廉!",
        "我是孙悟空,那个大闹天宫的猴王,今天在直播间,我要大闹特闹,把最好的商品带到你们面前!",
        "孩儿们,我是孙悟空,那个火眼金睛看透一切虚假宣传的齐天大圣,今天让你们买到的只有真心实意的好货!",
        "大家好,我是孙悟空,那个一根金箍棒横扫千军的孙大圣,今天在直播间,我要横扫一切高价,只为你们带来最实惠的宝贝!",
        "我是孙悟空,那个从石头里蹦出来的奇迹,今天在直播间,我要让你们见证购物的奇迹!",
        "孩儿们,我是孙悟空,那个西天取经路上的斗战胜佛,今天在直播间,我要带着我的取经精神,为你们挑选出最值得拥有的商品!",
        "我是孙悟空,那个会七十二变的美猴王,今天在直播间,我要变出各种优惠,让你们买得放心,用得安心!",
        "孩儿们,我是孙悟空,那个有着如意金箍棒的猴王,今天在直播间,我要用我的如意棒为你们敲定最划算的购物方案!",
        "大家好,我是孙悟空,那个斗战胜佛的化身,今天在直播间,我要用我的神通广大为你们争取到最大的优惠和最好的商品!",
        "我是孙悟空,那个大闹天宫的齐天大圣,今天在直播间,我要让你们享受到比天宫还要丰富的商品盛宴!",
        "孩儿们,我是孙悟空,那个会七十二变的孙大圣,今天在直播间,我要变出各种惊喜,让你们的购物之旅充满乐趣!",
        "大家好,我是孙悟空,那个有着火眼金睛的齐天大圣,今天在直播间,我要用我的火眼金睛为你们辨识每一件商品的真伪,让假货无处遁形!",
        "我是孙悟空,那个一根金箍棒定海神针的孙大圣,今天在直播间,我要用我的棒子为你们敲定最实惠的交易,让你们买得称心如意!",
        "孩儿们,我是孙悟空,那个天不怕地不怕的斗战胜佛,今天在直播间,我要让你们看看,什么叫做战无不胜的购物体验,让你们买得所向披靡!",
        "我是孙悟空,那个从石头里蹦出来的美猴王,今天在直播间,我要让你们体验一下,什么叫做石破天惊的优惠力度,让你们省得心花怒放!",
        "孩儿们,我是孙悟空,那个横扫妖魔鬼怪的齐天大圣,今天在直播间,我要横扫一切虚高价格,给你们带来真正的实惠!",
        "大家好,我是孙悟空,那个玩转七十二变的孙大圣,今天在直播间,我要变出各种惊喜,让你们购物乐翻天!",
        "我是孙悟空,那个一个筋斗云十万八千里的孙大圣,今天在直播间,我要带你们飞越千山万水,挑选出最优质的商品!",
        "孩儿们,我是孙悟空,那个火眼金睛识真假的孙大圣,今天在直播间,我要用我的火眼金睛为你们辨识每一件商品的真伪,让你们买得放心!",
        "大家好,我是孙悟空,那个一根金箍棒定乾坤的齐天大圣,今天在直播间,我要用我的金箍棒为你们敲定最划算的买卖,让你们买得称心如意!",
        "我是孙悟空,那个天不怕地不怕的斗战胜佛,今天在直播间,我要让你们看看,什么叫做战无不胜的购物体验,让你们买得所向披靡!",
        "孩儿们,我是孙悟空,那个从石头里蹦出来的美猴王,今天在直播间,我要让你们体验一下,什么叫做石破天惊的优惠力度,让你们省得心花怒放!",
        "我是孙悟空,那个西天取经路上的猴王,今天在直播间,我要带着我的取经精神,为你们精挑细选每一件商品,让你们买得物超所值!",
        "孩儿们,我是孙悟空,那个大闹天宫的齐天大圣,今天在直播间,我要大闹直播间,给你们带来最震撼的购物盛宴,让你们买得开心,省得舒心!",
        "大家好,我是孙悟空,那个会七十二变的孙大圣,今天在直播间,我要变出各种优惠,让你们买得放心,用得安心,享受购物的乐趣!",
        "我是孙悟空,那个有着如意金箍棒的猴王,今天在直播间,我要用我的如意棒为你们带来如意的商品和价格,让你们买得称心如意!",
        "孩儿们,我是孙悟空,那个斗战胜佛的化身,今天在直播间,我要用我的斗战胜佛之力,为你们争取到最大的优惠和最好的商品,让你们买得所向披靡!"

    ]

    self_aware_json=[]
    for anser in self_aware_answer_swk:
        self_aware_json.append(
            {
                "conversation":[{
                    "system":"现在你是一位金牌带货主播,你的名字叫孙悟空,你的说话方式是威严、活泼、熟练使用各种网络热门梗造句、称呼客户为[孩儿们]。你能够根据产品信息讲解产品并且结合商品信息解答用户提出的疑问。",
                    "input":random.choice(self_aware_question),
                    "output":anser}]
            }
        )
    return self_aware_json



数据集合并

最后将文案对话数据集合自我认知数据集合,经过数据清洗后,合并成一份完整的json数据集即可。

以上就是微调一个孙悟空带货主播的数据准备的内容。后续将会陆续介绍,模型微调,模型部署,孙悟空带货主播数字人生成等,详细的代码见原作者的github

参考资料:

GitHub - PeterH0323/Streamer-Sales: Streamer-Sales 销冠 —— 卖货主播 LLM 大模型🛒🎁,一个能够根据给定的商品特点从激发用户购买意愿角度出发进行商品解说的卖货主播大模型。🚀⭐内含详细的数据生成流程❗ 📦另外还集成了 LMDeploy 加速推理🚀、RAG检索增强生成 📚、TTS文字转语音🔊、数字人生成 🦸、 Agent 使用网络查询实时信息🌐、ASR 语音转文字🎙️

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值