基于lora对glm4-9b-chat模型的微调和推理

训练

from datasets import Dataset
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig
import torch
from peft import LoraConfig, TaskType, get_peft_model

# Local path to the base GLM-4-9B-Chat checkpoint.
MODEL_PATH = '/data/yangjianxie/glm-4-9b-chat'

# df = pd.read_excel('电池表-地点的扩充(2).xlsx')#read_json('data/result.json')
# Training data: GBK-encoded CSV expected to provide 'instruction', 'input'
# and 'output' columns (see process_func below).
df = pd.read_csv('data2.csv',encoding='gbk')
# print(df.head(1000))
# Wrap the DataFrame as a HF Dataset so we can .map() the tokenization over it.
ds = Dataset.from_pandas(df)
# ds = df.head(1000)

# Slow tokenizer + trust_remote_code: GLM-4 ships a custom tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)
# Reuse EOS as the pad token; process_func appends pad_token_id as the end-of-answer token.
tokenizer.pad_token = tokenizer.eos_token

def process_func(example):
    """Tokenize one (instruction, input, output) row into GLM-4 chat format.

    Builds the prompt with GLM-4's special tokens ([gMASK]<sop> plus role
    tags) spelled out literally, appends the SQL answer followed by one
    EOS/pad token, and masks the prompt portion of ``labels`` with -100 so
    the loss is computed on the response tokens only.

    Args:
        example: mapping with 'instruction', 'input' and 'output' fields
            (one row of the training CSV).

    Returns:
        dict with 'input_ids', 'attention_mask' and 'labels' token lists.
    """
    prompt_text = (f"[gMASK]<sop><|system|>\n你是一名sql专家。<|user|>\n"
                   f"{example['instruction']+str(example['input'])}<|assistant|>\n"
                   ).strip()
    # add_special_tokens=False: the special tokens are already written into
    # the template above, so the tokenizer must not add its own.
    instruction = tokenizer(prompt_text, add_special_tokens=False)
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    # pad_token_id == eos_token_id (set at load time), so the trailing token
    # acts as EOS; it gets attention mask 1 so the model learns to stop.
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    # -100 is ignored by the cross-entropy loss: train only on the response.
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    # NOTE(review): no truncation is applied — very long rows will produce
    # very long sequences; confirm the data fits in memory/context.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

# Tokenize every row; drop the raw text columns so only model inputs remain.
tokenized_id = ds.map(process_func, remove_columns=ds.column_names)
print(tokenized_id)

# Load the base model in bf16 (no device_map here — single-device training).
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, trust_remote_code=True)
# Required so gradient checkpointing works together with frozen base weights.
model.enable_input_require_grads()



config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],  # fine-tune only these module types for this demo
    inference_mode=False, # training mode
    r=8, # LoRA rank
    lora_alpha=32, # LoRA alpha scaling factor; see the LoRA paper for details
    lora_dropout=0.1  # dropout rate applied to the LoRA layers
)

# Wrap the base model with the LoRA adapters; base weights stay frozen.
model = get_peft_model(model, config)
model.print_trainable_parameters()


args = TrainingArguments(
    output_dir="glm4-lora-test2",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # effective batch size 8
    logging_steps=50,
    num_train_epochs=5,
    save_steps=50,
    learning_rate=1e-5,
    save_on_each_node=True,
    gradient_checkpointing=True      # trade compute for memory
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    # Seq2Seq collator pads input_ids/attention_mask and pads labels with -100.
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

data2.csv有三列:instruction,input,output

instruct:

       ### Database Schema


       
          CREATE TABLE `data_info_by_city` (
              `city` varchar(80) DEFAULT NULL COMMENT '城市',
              `area` varchar(80) DEFAULT NULL COMMENT '区域',
              `province` varchar(80) DEFAULT NULL COMMENT '省份',
              `cumulative_charging_swapping_times` bigint DEFAULT NULL COMMENT '累计充换电次数',
              `cumulative_charging_times` bigint DEFAULT NULL COMMENT '累计充电次数',
              `cumulative_swapping_times` bigint DEFAULT NULL COMMENT '累计换电次数',
              `daily_real_time_charging_swapping_times` bigint DEFAULT NULL COMMENT '今天实时充换电次数',
              `daily_real_time_charging_times` bigint DEFAULT NULL COMMENT '今天实时充电次数',
              `daily_real_time_swapping_times` bigint DEFAULT NULL COMMENT '今天实时换电次数',
              `battery_safety_rate`  bool DEFAULT NULL COMMENT '电池安全占比',
              `battery_non_safety_rate`  bool DEFAULT NULL COMMENT '电池非安全/风险/异常预警占比',
              `reduce_total_mileage` varchar(80) default null comment '节省总里程',
              `reduce_carbon_emissions` varchar(80) default null comment '减少碳排放',
             `battery_riding_mileage` bigint default null comment '电池当天行驶里程',
             `battery_non_safety_times` bigint default null comment '当天电池不安全/风险/异常预警次数'

        );

    CREATE TABLE `user_rider_info` (
      `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
      `city` varchar(80) DEFAULT NULL COMMENT '中文城市名称,如杭州市、北京市',
      `area` varchar(80) DEFAULT NULL COMMENT '中文区域名称,如余杭区、西湖区',
      `province` varchar(80) DEFAULT NULL COMMENT '中文省份名称,如浙江省、江苏省',
      `gender` bigint default null comment '性别 0表示男性,1表示女/女性/女生',
      `age` bigint default null comment '年龄',
      `hd_count` bigint default null comment '换电次数',
      `daily_hd_count` bigint default null comment '每天换电次数'
      )  comment '骑手表';

    CREATE TABLE `battery` (
      `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
      `city` varchar(80) DEFAULT NULL COMMENT '中文城市名称,如杭州市、北京市',
      `area` varchar(80) DEFAULT NULL COMMENT '中文区域名称,如余杭区、西湖区',
      `province` varchar(80) DEFAULT NULL COMMENT '中文省份名称,如浙江省、江苏省',
      `is_qishou` bool DEFAULT NULL COMMENT ' 0表示不在骑手中, 1表示在骑手中',
      `is_guizi`  bool DEFAULT NULL COMMENT ' 0表示不在电柜中, 1表示在电柜中',
      `type` bigint default null comment '1表示48伏/48V类型电池, 2表示60伏/60V类型电池, 3表示48MAX类型电池, 4表示60MAX类型电池',
      `status` bigint DEFAULT NULL COMMENT '0代表电池使用中,1代表电池充电中,2代表电池已满电'
      ) comment '电池表';

    CREATE TABLE `hdg_info` (
      `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
      `city` varchar(80) DEFAULT NULL COMMENT '中文城市名称,如杭州市、北京市',
      `area` varchar(80) DEFAULT NULL COMMENT '中文区域名称,如余杭区、西湖区',
      `province` varchar(80) DEFAULT NULL COMMENT '中文省份名称,如浙江省、黑龙江省'
      ) comment '电柜表';
    

       基于提供的database schema信息,回答问题:骑手手中有多少电池?

       input:空
       output:SELECT count(*) FROM battery WHERE is_qishou = 1 

推理

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import pandas as pd
from datasets import Dataset
from tqdm import tqdm


# Load evaluation data; only row 0's 'instruction' (the schema prompt) is used below.
df = pd.read_csv('data.csv', encoding='gbk').head(1000)
# index = 40008
# print('label:',df['output'][index])
# print('data:',df.iloc[index,:])

# NOTE(review): these three imports duplicate the ones at the top of the script.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

# Base checkpoint and the trained LoRA adapter checkpoint to load on top of it.
MODEL_PATH = '/data/yangjianxie/glm-4-9b-chat/'
lora_path = '/home/yangjianxie/text2sql/zheli/glm4-lora-test/checkpoint-60'

# Load the tokenizer (GLM-4 uses a custom implementation, hence trust_remote_code).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Load the base model in bf16, sharded across available GPUs, in eval mode.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
# print(model)
# print(sum(param.numel() for param in model.parameters()))

# Attach the fine-tuned LoRA adapter weights on top of the frozen base model.
model = PeftModel.from_pretrained(model, model_id=lora_path)
# print(model)
# print(sum(param.numel() for param in model.parameters()))

# prompt = df['instruction'][index]
# data = df.iloc[20000:21000,:]
#
# result = []
# for query in tqdm(data['instruction']):
# Test questions to substitute into the schema prompt, one generation each.
qs = ['有多少电池在骑手手中','有多少电池不在骑手手中','现在有多少电池在柜子里','现在有多少电池不在换电柜里','有多少满电电池','正在充电的电池有多少','充满电的电池有多少',
      '有多少48伏的电池','有多少60伏的电池','有多少48max类型的电池','有多少60max类型的电池']
for q in qs:
    print(q)
    # Reuse row 0's instruction (schema + example question) and swap in the new question.
    # NOTE(review): assumes row 0's instruction contains this exact question string — verify against data.csv.
    prompt = df['instruction'][0].replace('查询北京市有多少48MAX电池?', q)
    # NOTE(review): .to('cuda') with device_map="auto" assumes inputs belong on
    # the first GPU; confirm on multi-GPU setups.
    inputs = tokenizer.apply_chat_template([{"role": "system", "content": "你是一名sql专家。"},{"role": "user", "content": prompt}],
                                           add_generation_prompt=True,
                                           tokenize=True,
                                           return_tensors="pt",
                                           return_dict=True
                                           ).to('cuda')


    # NOTE(review): max_length counts prompt + generated tokens; top_k=1 with
    # do_sample=True is effectively greedy decoding.
    gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Strip the prompt tokens, keeping only the newly generated SQL.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        # result.append(tokenizer.decode(outputs[0], skip_special_tokens=True))

        print('inffer:',tokenizer.decode(outputs[0], skip_special_tokens=True))
        print('over')
  • 6
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值