第八节 LLaVA模型CLI推理构建custom推理代码Demo


前言

我在第七节介绍了cli.py推理源码解读,而我也因项目需要构建了推理demo,我们是用来自动生成标签和推理需要。想了想,我还是用一节将我的代码记录于此,供有需求读者使用。本节,介绍更改cli.py代码,实现一张图像推理、也为需要grounding的读者提供如何在图上给出目标box。


一、parser 参数设定

为什么我要单独介绍参数设定?因为它很重要,正确的设定会减少模型错误概率。我将介绍三个部分设定,一个是使用lora权重,一个是合并权重,最后一个是使用量化方式。

1、lora权重推理

我们训练模型多数使用lora训练,而未将lora训练结果合并的权重加载方式的方法。如果我们是使用自己训练方法,可以使用如下方式给出参数:

    parser.add_argument("--model-path", type=str, default="/extend_disk/disk3/tj/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch_clean2/checkpoint-10200")
    parser.add_argument("--model-base", type=str, default="/extend_disk/disk3/tj/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5")

如果我们是使用LLaVA自带lora方式,model-base基本不变,只需将model-path="/LLaVA/checkpoint/llava-v1.5-13b-lora",而权重下载我之前文章也介绍。

2、非lora权重推理

我们训练模型使用lora方法保存,想调用非lora方式,就需要将其转换。我们这里不说转换方法,给出非lora的权重加载方式。那这里只介绍官方给出权重加载参数设定,如下:

   parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b")
   parser.add_argument("--model-base", type=str, default=None)

3、量化权重推理

量化只需打开load-8bit或load-4bit参数,但量化必须是非lora权重加载方式,其代码如下:

   parser.add_argument("--load-8bit", action="store_true")
    # parser.add_argument("--load-4bit", default=True)
    parser.add_argument("--load-4bit", action="store_true")

当然量化显存占用测试,我们以LLAVA-13b量化显存测试:
不量化推理显存占用:28.4G
8bit量化推理显存占用:16.6G
4bit量化推理显存占用:10.6G

4、实验总结

我测试官方提供lora与非lora权重,我发现非lora效果会比lora好。当然这是我测试工程数据得到结论,只做参考。

二、初始化模型

我不在介绍,如下代码:

def llava_init(args):

    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)


    return tokenizer, model, image_processor, context_len,model_name

我想说,每个权重名称需包含v1字符,以便后续对话加载方式。

三、模型推理

模型推理,我将提示改成列表方式,我也对有框目标的文本预测做了图上画框操作。其它基本都是流程,我不在解读了。

四、完整代码Demo

最后,我给出完整的Demo,可以直接复制粘贴即可使用。若还想按照自己custom方式,读者也可根据我提供的方法来修改。其完整带阿米如下:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


import argparse
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer

def img_drawingbox(image,conversation_info,res_img_path=None):
    from PIL import Image, ImageDraw, ImageFont
    import re

 
    width, height = image.size

    draw = ImageDraw.Draw(image)
   
   
    box_lst = []
    for info in conversation_info['conversations']:
        value = info['value']
        gpt = info['from']
        if gpt == 'gpt':
            result = re.search(r'\[(.*?)\]', value)
            if result:
                content_in_brackets = result.group(1)
                # 将提取的内容转换为浮点数列表
                float_list = [float(num) for num in content_in_brackets.split(',')]
                if float_list not in box_lst:
                    box_lst.append(float_list)
    

    if len(box_lst)>0:
        for b in box_lst:
            if  len(b)==4:
                x1,y1,x2,y2 = b[0]*width,b[1]*height,b[2]*width,b[3]*height
                x1,y1,x2,y2=max(0,int(x1)),max(0,int(y1)),min(width,int(x2)),min(y2,height) 
                box=(x1,y1,x2,y2)
                # 绘制矩形框
                draw.rectangle(box, outline="red", width=2)  # 红色边框,宽度为2像素

    if res_img_path is not None:
        image.save(res_img_path,encoding="utf-8")
    return image



def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

def llava_init(args):

    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)


    return tokenizer, model, image_processor, context_len,model_name


def llava_infer(image,test_prompt,args,tokenizer, model, image_processor, model_name='llava_v1.5'):
   
    assert isinstance(test_prompt,list), "test_prompt提示文本必须是问题构成的列表!"

    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

 
    conversations_json = {'conversations':[]}
    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    width, height = image.size
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    for i ,inp in enumerate(test_prompt):
        
        conversations_json['conversations'].append({"from": "human","value":inp})

        if i==0:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else: #
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp  # 走这步变成  <image>\n描述图像内容
            conv.append_message(conv.roles[0], inp)
        else:
            # later messages # 后面循环对话添加内容
            conv.append_message(conv.roles[0], inp)
        
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]  # '</s>' ,这个是每句结束标志
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # 下面开始走模型
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()  # ouput_ids中去除input_ids位置prompt
        conv.messages[-1][-1] = outputs

        
        conversations_json['conversations'].append({"from": "gpt","value":outputs.replace('</s>','')})
    
        
    print(conversations_json)

    img_drawingbox(image,conversations_json,res_img_path=None)


    return conversations_json




def parse_args():
    parser = argparse.ArgumentParser()
    ## 直接使用合并后的模型进行推理
    # parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b")
    # parser.add_argument("--model-base", type=str, default=None)
    ## lora推理方法
    parser.add_argument("--model-path", type=str, default="/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch/checkpoint-10200")
    parser.add_argument("--model-base", type=str, default="/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5")
     
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    # parser.add_argument("--load-4bit", default=True)
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    
    args=parse_args()

    tokenizer, model, image_processor, context_len,model_name=llava_init(args)
    
    img_path = '/LLaVA/llava/serve/examples/1.jpg'

    images = load_image(img_path)

    test_prompt = ["图中是否有城市管理相关目标?若有,请提供相应坐标。"]
    predect_information_dict = llava_infer(images,test_prompt,args,tokenizer, model, image_processor, model_name)



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

tangjunjun-owen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值