One-click inference code for Qwen, BaiChuan, and Llama models

import torch
from threading import Thread
from epointml.utils import elog
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# AutoModelForCausalLM loads a pretrained autoregressive (causal) text-generation model

class run_dif_llm:
    def __init__(self, args, message, logger=None):
        self.args = args
        self.model_name = self.args.model_name
        self.model_path = self.args.model_path
        self.device = self.args.device      
        self.model_type = self.args.model_type
        self.message = message
        self.logger = elog() if logger is None else logger
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
    
    def get_model(self):
        if self.model_name == 'Llama':
            self.logger.info("loading Llama model")
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True).eval()
        elif self.model_name == 'BaiChuan' or self.model_name == 'QianWen':
            self.logger.info("loading BaiChuan or QianWen model")
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype='auto', device_map="auto", trust_remote_code=True).eval()

        if self.model_type == 'Chat':
            self.run_LLM_Chat(self.model, self.device, self.args)
        elif self.model_type == 'Base':
            self.run_LLM_Base(self.model, self.device)
        
    # Chat model: apply the chat template, generate, then decode only the new tokens
    def run_LLM_Chat(self, model, device, args):
        text = self.tokenizer.apply_chat_template(
            self.message,
            tokenize=False,
            add_generation_prompt=True
        )

        model_inputs = self.tokenizer([text], return_tensors="pt").to(device)

        generated_ids = model.generate(
            model_inputs.input_ids,
            # max_length = args.max_length,
            min_length = args.min_length,
            do_sample = args.do_sample,
            num_beams  = args.num_beams,
            # early_stopping = args.early_stopping,
            temperature = args.temperature,
            top_k = args.top_k,
            top_p = args.top_p,
            repetition_penalty = args.repetition_penalty,
            no_repeat_ngram_size  = args.no_repeat_ngram_size,
            max_new_tokens = args.max_new_tokens,
            # length_penalty = args.length_penalty,
        )
        
        # keep only the newly generated tokens by stripping the prompt prefix
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print(response)
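
    # Optional streaming sketch (an assumption, not in the original post): the
    # Thread and TextIteratorStreamer imports suggest token-by-token streaming
    # was intended; this variant prints text chunks as they are generated.
    def run_LLM_Chat_stream(self, model, device, args):
        text = self.tokenizer.apply_chat_template(
            self.message,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=args.max_new_tokens)
        # generate() blocks until finished, so run it on a background thread
        # and consume decoded chunks from the streamer on this one
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            print(new_text, end="", flush=True)
        thread.join()
        print()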
    
    # Base model: plain text continuation from a raw prompt
    def run_LLM_Base(self, model, device):
        inputs = self.tokenizer(self.message, return_tensors='pt')
        inputs = inputs.to(device)
        pred = model.generate(**inputs)  # uses the model's default generation settings
        print(self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
    
    # runs when the class instance is called
    def __call__(self, args):
        # get_model already dispatches on model_name and model_type,
        # so every supported combination goes through the same call
        self.get_model()
        
    



if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    
    
    # add arguments and their allowed values
    parser.add_argument('--model_name', default='BaiChuan', help='name of the model to run')     # QianWen  Llama  BaiChuan
    parser.add_argument('--model_type', default='Chat', help='variant of the model to run')      # Chat  Base
    
    # settings for model.generate(); the key ones are max_length, min_length, do_sample, top_k, top_p, and repetition_penalty
    
    parser.add_argument('--max_length', type=int, default=64, help='maximum total length of the generated sequence; generation stops once it is reached')
    parser.add_argument('--min_length', type=int, default=10, help='minimum length of the generated text, so outputs are not too short (default 10)')
    parser.add_argument('--do_sample', action='store_true', help='enable sampling; off by default, i.e. greedy decoding of the most probable token')
    parser.add_argument('--num_beams', type=int, default=1, help='number of beams; the default of 1 disables beam search')
    parser.add_argument('--early_stopping', action='store_true', help='stop beam search once at least num_beams candidate sentences are finished; off by default')
    parser.add_argument('--temperature', type=float, default=1.0, help='default 1.0; temperatures below 1 sharpen the softmax distribution, temperatures above 1 flatten it')
    parser.add_argument('--top_k', type=int, default=50, help='top-k filtering: keep the k highest-probability tokens as candidates (default 50)')
    parser.add_argument('--top_p', type=float, default=1.0, help='nucleus sampling: token probabilities sum to 1, so if top_p < 1, keep the highest-probability tokens whose cumulative probability reaches top_p')
    parser.add_argument('--repetition_penalty', type=float, default=1.0, help='penalty applied to repeated tokens (default 1.0, i.e. no penalty)')
    parser.add_argument('--no_repeat_ngram_size', type=int, default=0, help='controls repeated n-grams; default 0, and if greater than 0 each n-gram of that size may appear only once')
    parser.add_argument('--max_new_tokens', type=int, default=512, help='maximum number of newly generated tokens')
    parser.add_argument('--length_penalty', type=float, default=1.2, help='length penalty for beam search')
    
    parser.add_argument('--device', type=str, default='cuda', help='device to run on: cuda or cpu')
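
    # Optional sketch (not in the original post): these sampling settings can
    # equivalently be bundled into a transformers.GenerationConfig and handed
    # to model.generate as one object; the values below are only illustrative.
    #
    #   from transformers import GenerationConfig
    #   gen_config = GenerationConfig(
    #       do_sample=True, temperature=0.8, top_k=50, top_p=0.9,
    #       repetition_penalty=1.1, max_new_tokens=512,
    #   )
    #   model.generate(model_inputs.input_ids, generation_config=gen_config)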

    
    # parse the arguments
    args = parser.parse_args()
    args.model_path = '/var/DockerVolumes/data/llmmodel/Baichuan2-7B-Chat'
    # args.model_path = '/var/DockerVolumes/data/llmmodel/Qwen2-7B'
    # args.model_path = '/var/DockerVolumes/data/llmmodel/llama-3-sqlcoder-8b'
    
    # Chat models take a list of role/content messages
    if args.model_type=='Chat':
        message = [
            {
                "role":"user",
                "content":"介绍一下李白"
            }
        ]
        dif_llm = run_dif_llm(args, message)
        result = dif_llm(args)
    # Base models take a raw text prompt for continuation
    elif args.model_type=='Base':
        # "Climbing Stork Tower" -> Wang Zhihuan; the model should continue with
        # the poet of "Written on a Rainy Night to the North" (Li Shangyin)
        message = '登鹳雀楼->王之涣\n夜雨寄北->'
        dif_llm = run_dif_llm(args, message)
        result = dif_llm(args)

    
    
    # example terminal invocations (file saved as inference_llm.py)
    # BaiChuan-Chat      python inference_llm.py --model_name BaiChuan --model_type Chat
    # QianWen-Chat       python inference_llm.py --model_name QianWen --model_type Chat
    # Llama-Chat         python inference_llm.py --model_name Llama --model_type Chat
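    # Base-model runs are assumed to follow the same pattern (not shown in the
    # original post); point args.model_path at a matching Base checkpoint first:
    # BaiChuan-Base      python inference_llm.py --model_name BaiChuan --model_type Base
    # QianWen-Base       python inference_llm.py --model_name QianWen --model_type Base
    # Llama-Base         python inference_llm.py --model_name Llama --model_type Base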
    
    
    