First Experience with the Multimodal LLMs Qwen-VL and MiniCPM-Llama3-V-2_5

Table of Contents

1. Overview of QwenVL and MiniCPM-Llama3-V-2_5

Qwen-VL-Chat

MiniCPM-Llama3-V-2_5

2. How QwenVL and MiniCPM-Llama3-V-2_5 Work

Mainstream ways of combining visual features with LLMs

How the Qwen-VL model works

How the MiniCPM-Llama3-V-2_5 model works

Inference code

QwenVL

MiniCPM-Llama3-V-2_5

3. Testing QwenVL and MiniCPM-Llama3-V-2_5

Results without fine-tuning

Results after fine-tuning

Demo

4. Reflections



        Riding on my company's vision-related business, I took the opportunity to study and try out multimodal LLMs. The work was mainly a pre-study of Qwen-VL and MiniCPM-Llama3-V-2_5: understanding how visual features get fused into an LLM, and checking how impressive the two models' OCR abilities really are. This post introduces the two models, digs into how visual features are combined with the LLM, and shares some conclusions from fine-tuning experiments.

1. Overview of QwenVL and MiniCPM-Llama3-V-2_5

Qwen-VL-Chat

       QwenVL, touted as the best multimodal LLM in China, is one of the multimodal models in Alibaba's Tongyi Qianwen family. The QwenVL series has three models: Qwen-VL-Chat, Qwen-VL-Plus, and Qwen-VL-Max. Qwen-VL-Chat is open source (both code and weights), while the stronger Qwen-VL-Plus and Qwen-VL-Max are not, though they can be accessed through the 🤗🤖 web demos, the app, and the API. Our focus is the open-source Qwen-VL-Chat, abbreviated QwenVL below. At the architecture level, the LLM comes from Qwen-7B and the vision transformer is ViT-bigG; the total parameter count is about 9.5B, and the fp16 weight files add up to roughly 19 GB. It offers many vision capabilities: DocVQA (document understanding), image QA, image understanding, OCR, multi-image input and multi-image QA, multi-turn dialogue, and bounding-box grounding; see the examples in the paper (Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond):

The training process is as follows:

Stage 1: freeze the Qwen LLM weights; train the Q_Transformer (adapter) weights and the ViT weights.

Stage 2: unfreeze all weights for multi-task pretraining.

Stage 3: freeze the ViT weights and run supervised training on downstream tasks with the remaining weights.

Note the composition of their training data:

Chinese accounts for only 23.24% of the 1.4B pairs, and about 220M of that (roughly 67.7% of the Chinese portion) is in-house data. Judging from my later experiments, QwenVL's Chinese OCR is indeed not impressive, while Qwen-VL-Plus and Qwen-VL-Max, as experienced on the web, perform very well; my guess is better training data and larger model scale.

MiniCPM-Llama3-V-2_5

This model is a multimodal LLM from Tsinghua University and ModelBest (面壁智能). Its language model is based on the Llama3 architecture, its vision transformer is the siglip-so400m model, the total parameter count is 8B, and the fp16 weights take about 16 GB of space. I could not find an official paper, so here are the multimodal capabilities it demonstrates:

Document understanding

Table understanding

plus multi-turn dialogue and more.

2. How QwenVL and MiniCPM-Llama3-V-2_5 Work

Mainstream ways of combining visual features with LLMs

        First we need to understand how a text-only LLM and a vision model are stitched together. The mainstream approach today is to fuse image features with text features and feed the result into the LLM for training and inference. As for the fusion strategy, one option is a Query-Transformer and the other is an MLP; both Qwen-VL and MiniCPM-Llama3-V-2_5 covered here fuse features with a Query-Transformer. Most models up to now have used a Query-Transformer to compress image information and cut the number of tokens an image occupies, but MLPs look set to become more popular. For a proper analysis, see the Zhihu post 多模态大语言模型(MLLM)为什么最近的工作中用BLIP2中Q-Former结构的变少了 (why recent MLLM work uses BLIP-2's Q-Former less); I have not run those experiments, so I will not attempt that analysis myself. The fusion principle is shown in the figure below (from the Zhihu article 多模态大模型:视觉模型与LLM的结合之路(四)):

From the figure we can read off how images are combined with the text LLM; a minimal sketch follows the list.

1. The image passes through a vision encoder to produce visual features img_emb.

2. img_emb passes through a compression/transformation layer (the adapter), which aligns its dimensionality with the text features prompt_emb (the last dimension of the matrices becomes equal); to keep the aligned img_emb from occupying too many token positions, it is also compressed down to a fixed number of tokens.

3. The aligned img_emb and prompt_emb are concatenated and fed into the LLM.
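
A rough PyTorch sketch of these three steps (my own minimal reconstruction for illustration only, not either model's actual source; all dimensions are made up):

import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Align img_emb with the text embedding dim and compress it to a fixed token count."""
    def __init__(self, img_dim=1664, llm_dim=4096, num_queries=256):
        super().__init__()
        self.proj = nn.Linear(img_dim, llm_dim)              # align the last dimension
        self.queries = nn.Parameter(torch.randn(num_queries, llm_dim))
        self.attn = nn.MultiheadAttention(llm_dim, 8, batch_first=True)

    def forward(self, img_emb):                              # (B, n_patches, img_dim)
        kv = self.proj(img_emb)                              # (B, n_patches, llm_dim)
        q = self.queries.expand(img_emb.size(0), -1, -1)     # (B, 256, llm_dim)
        out, _ = self.attn(q, kv, kv)                        # (B, 256, llm_dim)
        return out

# 1. vision encoder -> img_emb; 2. adapter -> 256 tokens; 3. concat with prompt emb
img_emb = torch.randn(1, 1024, 1664)      # stand-in for ViT patch features
prompt_emb = torch.randn(1, 32, 4096)     # stand-in for text token embeddings
fused = torch.cat([Adapter()(img_emb), prompt_emb], dim=1)   # goes into the LLM
print(fused.shape)                        # torch.Size([1, 288, 4096])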

How the Qwen-VL model works

First, the model architecture diagram:

Qwen-VL's LLM is the Qwen-7B pretrained model with 32 QWenBlock layers; the vision encoder is OpenCLIP's pretrained ViT-bigG with 48 Transformer blocks; the adapter holds 256 query nodes and a single attention layer, with 2D positional encodings added. Note that q and k get different positional encodings, for the simple reason that q (length 256) and k (length 1024) have different lengths.

The model input splices the image path img_path and the text prompt together as a single sequence; later in the pipeline, img_path is extracted from the input, the image is read to get img, and img is fed into the visual module (VisionTransformer). The visual module contains the Transformer blocks that extract image features as well as the resampler module (i.e. the adapter) that compresses and transforms them. After the Transformer blocks, img becomes the high-level feature img_feature; the resampler then aligns img_feature with the text_embedding dimension and compresses it onto a fixed number of tokens. text_embedding and img_embedding are concatenated directly to get the fused text-image features; in the figure above, the 256 pink blocks represent one image's compressed and transformed token embeddings. Once you understand this flow, you essentially understand how Qwen-VL fuses visual and text features before handing them to the LLM; further details of ViT-bigG, Qwen-7B, and the adapter are out of scope here.

Fusion code for text_embedding and the (aligned) img_embedding:

        hidden_states = self.drop(hidden_states).clone()
        if fake_images is not None:
            hidden_states = hidden_states + images.mean()*0
        elif images is not None:
            # Overwrite the placeholder positions (between <img> and </img>)
            # with the aligned image embeddings
            for idx, (i, a, b) in enumerate(img_pos):
                hidden_states[i][a + 1 : b] = images[idx]

The positions in text_embedding that are reserved for image features are simply overwritten with the corresponding img_embedding.

To see the whole text-image fusion process more clearly, the following shows the intermediate artifacts after input and the matrix shape changes through each module.

What the text looks like after the tokenizer:

<img></img>

<img>/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/image_with_text.jpg</img>

The <img>img_path</img> span occupies 258 token positions, of which img_path takes 256; the embeddings at those 256 positions are later replaced by the aligned img_embedding, exactly as in the code above.

The input-to-embedding flow looks like this:

The raw input is tokenized into input_ids, from which img_path is decoded back out; the image is read and put through a series of operations (resize, conv2d to patches) that "tokenize" it so it can enter the pretrained vision transformer to extract img_feature. Finally, the linear layer in the adapter aligns the dimension with the text features, and only then does cross-attention compress img_feature down to a fixed number of tokens, shrinking the token budget the image consumes.
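
As a sanity check on these numbers (assuming the 448*448 input resolution and patch_size 14 described above):

image_size, patch_size = 448, 14
grid = image_size // patch_size        # 32 patches per side
num_patches = grid * grid              # 1024 patch features out of the ViT (the k length)
num_queries = 256                      # fixed query count in the adapter (the q length)
print(num_patches, "->", num_queries)  # 1024 -> 256, a 4x token compression
# In the prompt: <img> + 256 placeholder tokens + </img> = 258 positions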

How the MiniCPM-Llama3-V-2_5 model works

The model structure is much the same as Qwen-VL's; the multimodal fusion idea is identical: compress the image onto a fixed number of tokens, using a Query-Transformer in both cases. The differences: MiniCPM-Llama3-V-2_5's LLM module is Llama3; its vision transformer is the pretrained siglip-so400m; and the adapters are nearly the same, with image features passing through a Linear layer to align with the text features and then through a cross_attn module for compression.

As the figure above shows, the overall structure matches; in the resampler's cross_attn, only K gets a positional encoding. One bigger difference is that MiniCPM-Llama3-V-2_5 adapts to the input image's size: each original image is transformed into several new images, each occupying 96 tokens, instead of Qwen-VL's uniform scheme of interpolating every image to 448*448 (scale_resolution*scale_resolution) and compressing it to a fixed 256 tokens. The processing flow is:

1. If the original image's area is smaller than 448*448 (scale_resolution*scale_resolution), upsample it under these rules: the upsampled image keeps the original aspect ratio, and both height and width are divisible by patch_size (=14). The enlarged image source_upsample_image then serves as the intermediate representation fed into the model.

2. If the original image's area is larger than 448*448 (scale_resolution*scale_resolution), downsample it by the same rules to get source_downsample_image. To retain more image information, the original image is additionally transformed (mostly enlarged) by certain rules, and a number of patch sub-images are cut cleanly out of the transformed image; these patches together with source_downsample_image become the intermediate representation fed into the model. Because source_downsample_image and the patches may differ in size, padding is added. Patch slicing follows the rules below (for the exact implementation, read the model source):

         a. after the transformation, the image can be cut into exactly m patches, where 2<=m<=9, m=x*y, x>=1, y>=1 (x and y give the row/column layout of the sub-images), and the x/y ratio deviates least from the original image's height-to-width ratio;

         b. each patch's height and width are divisible by patch_size (=14), and the patch height = round(scale_resolution * height / width).

My guess is that this design tries to preserve as much of the original image's information as possible: the enlarged image is not distorted, the aspect ratio barely changes, and token usage stays as low as possible. Note that one input image turns into 1+n small images after this procedure, and every small image is ultimately compressed to 96 tokens. A minimal sketch of the grid-selection rule (a) follows.
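
To make rule (a) concrete, here is a minimal sketch of the grid selection (my own reconstruction of the idea, not the model source; find_best_grid is a hypothetical name): choose the x*y grid whose aspect ratio is closest, in log space, to the image's.

import math

def find_best_grid(width, height, scale_resolution=448):
    # Slice count suggested by the image area relative to 448*448
    multiple = min(9, max(2, round(width * height / scale_resolution ** 2)))
    log_ratio = math.log(width / height)
    best, best_err = None, float("inf")
    for m in (multiple - 1, multiple, multiple + 1):   # try nearby slice counts too
        if m < 2 or m > 9:
            continue
        for x in range(1, m + 1):                      # every x*y factorization of m
            if m % x != 0:
                continue
            y = m // x
            err = abs(log_ratio - math.log(x / y))
            if err < best_err:
                best, best_err = (x, y), err
    return best

print(find_best_grid(1260, 644))   # -> (3, 1): a wide image gets 3 columns x 1 row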

To see the whole text-image fusion process more clearly, again look at the intermediate artifacts.

What the text looks like after the tokenizer:

Image placeholders

prompt: 请识别图片中的全部文本 (i.e. "recognize all the text in the image")

img: 500*500

The final text fed into the tokenizer is just the image placeholders above plus the Chinese prompt, as the figure below shows.

The image itself is processed into one 448*448 source_downsample_image plus two 630*322 patches.

Likewise, after the input enters the model, the intermediate tensors change shape as follows:

Inference code

A simple piece of inference code follows.

QwenVL

The test images are:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from transformers import AutoModelForCausalLM, AutoTokenizer
import base64

def img2base64(file_name):
    with open(file_name, 'rb') as f:
        encoded_string = base64.b64encode(f.read())
        return encoded_string

if __name__ == '__main__':

    model_path = "/AI_TEAM/yanghuang/pretrain_models/torch/Qwen-VL-Chat"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda", bf16=True,
                                            trust_remote_code=True).eval()

    img_path = "/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/test.jpg"
    prompt = "请输出图片中的文字"
    # First round of dialogue
    query = tokenizer.from_list_format([
        {'image': img_path},  # Either a local path or an url
        {'text': f'{prompt}'},
    ])
    print(f"query---: {[query]}")
    response, history = model.chat(tokenizer, query=query, history=None)
    print("response---",[response])
    print("history---", [history])
    print("*"*100)

    # Second round: a new image, carrying over the history
    query = tokenizer.from_list_format([
        {'image': "/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/image_with_text.jpg"},  # Either a local path or an url
        {'text': f'好的,那新的图片中是什么内容'},
    ])
    print(f"query---: {[query]}")
    response, history = model.chat(tokenizer, query=query, history=history)
    print("response---", [response])
    print("history---", [history])

The results are as follows:

MiniCPM-Llama3-V-2_5

Image:

Code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import json
from PIL import Image
import base64
import io
from transformers import AutoTokenizer, AutoModel
from peft import AutoPeftModelForCausalLM


class MiniCPMV2_5:
    def __init__(self, model_path, adapter_path=None) -> None:
        if adapter_path:
            # Load the LoRA adapter, then restore the separately saved
            # vision tower / resampler / embed_tokens weights on top of it
            self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_path, trust_remote_code=True).to(dtype=torch.bfloat16)
            vpm_resampler_embedtokens_weight = torch.load(f"{adapter_path}/vpm_resampler_embedtokens.pt")
            self.model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False)
        else:
            self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
        self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    def chat(self, input, img_base64_list=None):
        try:
            image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
        except Exception as e:
            return "Image decode error"

        msgs = json.loads(input['question'])

        answer = self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7
        )
        return answer

def img2base64(file_name):
    with open(file_name, 'rb') as f:
        encoded_string = base64.b64encode(f.read())
        return encoded_string

if __name__ == '__main__':
    model_path = '/AI_TEAM/yanghuang/pretrain_models/torch/MiniCPM-Llama3-V-2_5'
    adapter_path = "/AI_TEAM/yanghuang/workspace/project/MiniCPM-V/output/solid_bg_imgs_20240621/checkpoint-72000"
    model = MiniCPMV2_5(model_path=model_path)

    img_path = "./500_500_test.jpg"
    im_64 = img2base64(img_path)
    msgs = [{"role": "user", "content": "请识别图片中的全部文本"}]
    inputs = {"image": im_64, "question": json.dumps(msgs)}
    response = model.chat(inputs)
    print(response)

3. Testing QwenVL and MiniCPM-Llama3-V-2_5

Now for the exciting experiments. The focus here is OCR, so I generated images of text on solid-color backgrounds. The image-generation code is below:

from PIL import Image, ImageDraw, ImageFont
import glob
import os
import pandas as pd
import random
import uuid
import json
import tqdm
import math
random.seed(100)


def do_create(font_paths, font_sizes, positions, texts, img_color, text_colors,  width, height, save_path):
    image = Image.new('RGB', (width, height), img_color)
    # Create a drawing object for the image
    draw = ImageDraw.Draw(image)
    for font_path, position, text, text_color, font_size in zip(font_paths, positions, texts, text_colors, font_sizes):
        font = ImageFont.truetype(font_path, font_size)
        draw.text(position, text, font=font, fill=text_color)
    image.save(save_path)


def get_text_and_img_color(white_img_p = 0.75):
    # Background color
    if random.random() < white_img_p:
        # Pure white background
        img_color = (255, 255, 255)
        text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
    else:
        img_color = (random.choice(range(256)), random.choice(range(256)), random.choice(range(256)))
        text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
        # Re-draw until the text color differs from the background color
        while text_color == img_color:
            text_color = (random.choice(range(255)), random.choice(range(255)), random.choice(range(255)))
    return img_color, text_color



def compute_single_row_img_infos(text, width_max = 1344, height_max = 1344):
    font_size_list = list(range(10, 25))

    # Background color
    img_color, text_color = get_text_and_img_color(white_img_p = 0.75)

    # Font size and page layout
    ft_size = random.choice(font_size_list)

    # Do not truncate the text itself here
    height = random.choice(range(200, 500, 10))
    width = ft_size * len(text) + random.choice(range(50, 100))
    # With probability 0.4, cut some characters off the left or right end
    if random.random() < 0.4:
        if int(len(text) * 0.5) > 1:
            candidate_cut = random.choice(range(1, int(len(text) * 0.5)))
            if random.random() < 0.5:
                text = text[0:-candidate_cut]
            else:
                text = text[candidate_cut:]
    assert width < width_max, "width > width_max"
    assert len(text) < 2048
    pos_x = random.choice(range(2, width - ft_size * len(text)))
    pos_y = random.choice(range(0, height - ft_size))
    position = (pos_x, pos_y)

    return text, ft_size, width, height, position, img_color, text_color


def create_single_row_words_img():
    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/single_row"
    if not os.path.exists(path):
        os.makedirs(path)
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")

    with open('./one_row_content.jsonl','r', encoding='utf-8') as f:
        lines = f.readlines()

    total_count = 0
    with open('./solid_bg_imgs/single_row.json', 'w', encoding='utf-8') as f:
        for line in tqdm.tqdm(lines[0:], desc="create_single_row_words_img"):
            try:
                text = json.loads(line)['content']
                ft_ps = [random.choice(font_paths)]
                text, ft_size, width, height, position, img_color, text_color = compute_single_row_img_infos(text)
                save_path = os.path.join(path, f"{total_count}.jpg")
                ft_sizes = [ft_size]
                positions = [position]
                texts = [text]
                text_colors = [text_color]
                do_create(ft_ps, ft_sizes, positions, texts, img_color, text_colors, width, height, save_path)
                temp = {
                    "context": text,
                    "path": save_path
                }
                f.write(json.dumps(temp, ensure_ascii=False) + '\n')
                total_count += 1
            except Exception as e:
                print(e)
    print(f"create_single_row_words_img count {total_count}")

def multi_0_thousand_word():

    def compute_imgs_text_label(ft_se, poisitons, texts, width):
        label = ""
        new_texts = []
        for ft, poi, text in zip(ft_se, poisitons, texts):
            # Number of characters that fit between the x position and the right edge
            word_count = int((width - poi[0]) / ft)
            text = text[0:word_count]
            new_texts.append(text)
            label += text + "\n\n"

        label = label.strip('\n\n')
        # Return the truncated texts so the rendered image matches the label
        return label, new_texts

    width = 500
    height = 500
    font_size_list = list(range(15, 25))
    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/thousand_word"
    if not os.path.exists(path):
        os.makedirs(path)
    df = pd.read_csv("./classification_train_dataset_75W_1130.tsv", sep='\t')
    sentences1 = df['sentence1'].values.tolist()
    sentences2 = df['sentence2'].values.tolist()
    sentences = sentences1 + sentences2
    sentences = [ str(ele) for ele in sentences]
    sentences = list(set(sentences))
    sentences.sort()
    print("len(sentences)", len(sentences))

    targets = []
    text_path = "/data02/yanghuang/workspace/Qwen-VL/datas/海尔热线报装报修_20240327_179672.jsonl"
    with open(text_path, 'r',
              encoding='utf-8') as reader:
        lines = reader.readlines()
        for line in lines:
            target = json.loads(line)['target']
            targets.append(target)
    targets = list(set(targets))
    targets.sort()
    print(f"len(targets) {len(targets)}")

    # font_paths = glob.glob("/AI_TEAM/yanghuang/workspace/project/Qwen-VL/datas/fonts/*")
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")
    rows = list(range(2, 7))
    row_spaces = list(range(20, 40))
    with open('./solid_bg_imgs/multi_rows/thousand_word.json', 'w', encoding='utf-8') as f:
        total_count = 0
        for _ in tqdm.tqdm(range(25000), desc="create_multi_rows_words_img thousand_word"):
            row_count = random.choice(rows)
            # Sample fonts without replacement
            ft_ps = random.sample(font_paths, k=row_count)
            # Sample font sizes with replacement
            ft_se = random.choices(font_size_list, k=row_count)
            texts = random.sample(sentences, k=row_count)
            poisitons = []
            first_row_height = 10
            for i in range(row_count):
                if i == 0:
                    poisitons.append((random.choice(range(0, 200)), first_row_height))
                else:
                    poisitons.append(
                        (random.choice(range(0, 200)), poisitons[i - 1][1] + ft_se[i - 1] + random.choice(row_spaces)))

            label, texts = compute_imgs_text_label(ft_se, poisitons, texts, width)
            save_path = os.path.join(path, f"{total_count}.jpg")
            img_color = (255, 255, 255)
            text_colors = [(random.choice(range(255)), random.choice(range(255)), random.choice(range(255))) for _ in
                           range(row_count)]
            temp = {
                "context": label,
                "path": save_path,
                "textsource": "thousand_word"
            }
            f.write(json.dumps(temp, ensure_ascii=False) + '\n')

            do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                      height=height, save_path=save_path)
            total_count +=1

        # Generate images for the 海尔热线报装报修 (Haier hotline install/repair) texts
        path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/海尔热线报装报修"
        if not os.path.exists(path):
            os.makedirs(path)
        for index in tqdm.tqdm(range(25000), desc=f"create_multi_rows_words_img 海尔热线报装报修"):
            target = targets[index].split('\n\n')
            row_count = len(target)
            # Use one font for all rows
            ft_ps = [random.choice(font_paths)] * row_count

            # Use one font size for all rows
            ft_se = [random.choice(font_size_list)] * row_count
            texts = target
            poisitons = []
            first_row_height = 10
            posi_x = random.choice(range(0, 200))
            space = random.choice(row_spaces)
            for i in range(row_count):
                if i == 0:
                    poisitons.append((posi_x, first_row_height))
                else:
                    poisitons.append(
                        (random.choice(range(0, 200)), poisitons[i - 1][1] + ft_se[i - 1] + random.choice(row_spaces)))

            label, texts = compute_imgs_text_label(ft_se, poisitons, texts, width)
            save_path = os.path.join(path, f"{total_count}.jpg")
            img_color = (255, 255, 255)
            text_colors = [(random.choice(range(255)), random.choice(range(255)), random.choice(range(255))) for _ in
                           range(row_count)]
            temp = {
                "context": label,
                "path": save_path,
                "textsource": "海尔热线报装报修"
            }
            f.write(json.dumps(temp, ensure_ascii=False) + '\n')
            do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                      height=height, save_path=save_path)
            total_count += 1


def multi_1_poet():
    width_max = 600
    height_max = 800
    font_size_list = list(range(10, 35))
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")

    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/poet"
    if not os.path.exists(path):
        os.makedirs(path)

    poets_map = {}
    poets_index = []
    poet_paths = glob.glob("/data02/yanghuang/datasets/LLM_datasets/chinese-poetry/全唐诗/简体/*.json")
    index = 0
    total = 0
    for poet_path in tqdm.tqdm(poet_paths, desc="get_poets"):
        with open(poet_path, 'r', encoding='utf-8') as f:
            poets = json.load(fp=f)
        for poet in poets:
            paragraphs = poet['paragraphs']
            total += 1
            if 2 <= len(paragraphs) <= 4 and poet['title'] != "":
                poets_map[index] = poet
                poets_index.append(index)
                index += 1
    print(total)
    print(len(poets_index))
    def compute_poets_img_infos(poets, font_size_list, font_paths, width, height):
        # Line-spacing ratios (within a poem vs between poems)
        ratio_sinle = random.choice([1.0,1.1,1.2,1.3,1.4,1.5])
        ratio_multi = random.choice([3,4,5])
        ft_ps = []
        ft_se = []
        poisitons = []
        texts = []
        text_colors = []
        label = ""
        if random.random() < 0.6:
            # Uniform font, color, and layout for every poem
            ft_p = random.choice(font_paths)
            ft_s = random.choice(font_size_list)
            # Background color
            img_color, text_color = get_text_and_img_color(white_img_p=0.75)

            counts = [ len(poet['paragraphs'][0]) for poet in poets]
            counts.extend([len(poet['title'])+len(poet['author'])+len("——") for poet in poets])

            word_count_max = max(counts)
            posi_x = random.choice(range(0,  width - word_count_max * ft_s-10))
            for index, poet in enumerate(poets):
                author = poet['author']
                paragraphs = poet['paragraphs']
                title = poet['title']
                text = f"{title}——{author}"
                if index == 0:
                    position = (posi_x, 10)
                else:
                    position = (posi_x, poisitons[-1][1] + ft_s*ratio_multi)
                ft_ps.append(ft_p)
                ft_se.append(ft_s)
                poisitons.append(position)
                texts.append(text)
                text_colors.append(text_color)
                label += text + '\n'
                for sen in paragraphs:
                    text = sen
                    position = (posi_x, poisitons[-1][1]+int(ft_s*ratio_sinle))
                    ft_ps.append(ft_p)
                    ft_se.append(ft_s)
                    poisitons.append(position)
                    text_colors.append(text_color)
                    texts.append(text)
                    label += text + '\n'
                label = label.strip('\n') + '\n\n'
            label = label.strip("\n\n")
        else:
            for index, poet in enumerate(poets):
                author = poet['author']
                paragraphs = poet['paragraphs']
                title = poet['title']
                text = f"{title}——{author}"
                # Font and size (per poem)
                ft_p = random.choice(font_paths)
                ft_s = random.choice(font_size_list)
                # Background color
                img_color, text_color = get_text_and_img_color(white_img_p=0.75)
                word_count_max = max([len(poet['paragraphs'][0]), len(poet['title'])+len(poet['author'])+len("——")])
                posi_x = random.choice(range(0, width - word_count_max * ft_s - 10))

                if index == 0:
                    position = (posi_x, 10)
                else:
                    position = (posi_x, poisitons[-1][1] + ft_se[-1]*ratio_multi)
                ft_ps.append(ft_p)
                ft_se.append(ft_s)
                poisitons.append(position)
                texts.append(text)
                text_colors.append(text_color)
                label += text + '\n'
                for sen in paragraphs:
                    text = sen
                    position = (posi_x, poisitons[-1][1]+ int(ft_s*ratio_sinle))
                    ft_ps.append(ft_p)
                    ft_se.append(ft_s)
                    poisitons.append(position)
                    text_colors.append(text_color)
                    texts.append(text)
                    label += text + '\n'
                label = label.strip('\n') + '\n\n'
            label = label.strip("\n\n")

        assert poisitons[-1][1] < height, "beyond img height"
        height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20,150))
        return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height

    total_count = 0
    with open('./solid_bg_imgs/multi_rows/poet.json', 'w', encoding='utf-8') as f:
        while len(poets_index) > 0:
            if random.random() < 0.3:
                choiced_indexs = random.sample(poets_index, k = 1)
            elif random.random() <0.6:
                choiced_indexs = random.sample(poets_index, k = 2 if len(poets_index) >= 2 else 1)
            else:
                choiced_indexs = random.sample(poets_index, k = 3 if len(poets_index) >= 3 else len(poets_index))

            poets_index = list(set(poets_index) - set(choiced_indexs))
            poets_index.sort()

            try:
                poets = [poets_map[ele] for ele in choiced_indexs]
                ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_poets_img_infos(poets,
                                                                                                        font_size_list,
                                                                                                        font_paths,
                                                                                                        width_max, height_max)
                save_path = os.path.join(path, f"{total_count}.jpg")
                do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                          height=height, save_path=save_path)

                total_count += 1
                temp = {
                    "context": label,
                    "path": save_path,
                    "textsource": "poet"
                }
                print(f"\rtotal_count: {total_count}", end="")
                f.write(json.dumps(temp, ensure_ascii=False) + '\n')
            except Exception as e:
                continue

            if len(poets_index) == 0:
                break
    print("")
    print(f"total_count: {total_count}")


def multi_2_dialog():
    paths = [
        "/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/达能_20240403_80508_2_aug_20240403_1414_80508.jsonl",
        "/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/美素工单小结_20240329_2687_aug_20240329_1710_2687.jsonl",
        "/data02/yanghuang/workspace/llm_platform/trainer/dataset/factor/太平寿险工单总结_20240605_60422_aug_20240605_1129_60422.jsonl",
        "/data02/yanghuang/workspace/llm_platform/trainer/dataset/general/圆通总结摘要_20240614_14036.jsonl",
    ]
    dialogs = []
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            datas = f.readlines()
            print(f"{path}--{len(datas)}")
            for data in datas:
                dialog = json.loads(data)['context'].split("\n\n")[0]
                if 2048 >= len(dialog) >= 1 and dialog != "":
                    dialogs.append(dialog)
    print(f"len(dialogs) {len(dialogs)}")
    dialogs = list(set(dialogs))
    dialogs.sort()
    print(f"len(dialogs) {len(dialogs)}")
    width_max = 1300
    height_max = 1300
    font_size_list = list(range(10, 35))
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")

    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/business_dialog"
    if not os.path.exists(path):
        os.makedirs(path)


    def compute_dialog_img_infos(dialog, font_size_list, font_paths, width, height):
        ratio_sinle = random.choice([1.0, 1.1, 1.2, 1.3, 1.4, 1.5])

        ft_ps = []
        ft_se = []
        poisitons = []
        texts = []
        text_colors = []
        label = ""

        dialog = dialog.split('\n')

        # Uniform font, color, and layout for the whole dialog
        ft_p = random.choice(font_paths)
        ft_s = random.choice(font_size_list)
        # Background color
        img_color, text_color = get_text_and_img_color(white_img_p=0.75)
        posi_x = random.choice(range(10, 50))

        row_count = 1100//int(ft_s*ratio_sinle)
        for index, text in enumerate(dialog[0:row_count]):
            ft_ps.append(ft_p)
            ft_se.append(ft_s)
            text_colors.append(text_color)
            if index == 0:
                position = (posi_x, 10)
            else:
                position = (posi_x, poisitons[-1][1] + int(ft_s*ratio_sinle))

            poisitons.append(position)
            text = text[: int((width - posi_x)/ft_s)]
            assert width >= posi_x + len(text) * ft_s ,"sentence too long"
            texts.append(text)
            label += text +'\n'

        assert poisitons[-1][1] < height, "turns too long"
        height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20, 100))

        return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height


    total_count = 0
    with open('./solid_bg_imgs/multi_rows/business_dialog.json', 'w', encoding='utf-8') as f:
        for dialog in tqdm.tqdm(dialogs[0:], desc="multi_2_dialog"):
            try:
                ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_dialog_img_infos(
                    dialog, font_size_list, font_paths, width_max, height_max)
                save_path = os.path.join(path, f"{total_count}.jpg")
                do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                          height=height, save_path=save_path)
                total_count += 1
                temp = {
                    "context": label,
                    "path": save_path,
                    "textsource": "business_dialog"
                }
                f.write(json.dumps(temp, ensure_ascii=False) + '\n')
            except Exception as e:
                print(e)
                continue
        print(f"\ntotal_count: {total_count}")

def compute_multi_paragraph_img_infos(content, font_size_list, font_paths, width_max, height_max):
    # Line-spacing ratios
    ratio_sinle = random.choice([1.0, 1.1, 1.2, 1.3, 1.4, 1.5])
    ratio_multi = random.choice([3, 4, 5])

    token_size_max = int(math.sqrt(width_max*height_max/2.1/len(content)))

    assert token_size_max > font_size_list[0], "font size too small"

    # Re-draw until the font size fits the computed maximum
    ft_s = random.choice(font_size_list)
    while ft_s > token_size_max:
        ft_s = random.choice(font_size_list)

    ft_ps = []
    ft_se = []
    poisitons = []
    texts = []
    text_colors = []
    label = ""

    row_index = 0

    if random.random() < 0.65:
        # Uniform font type and color
        ft_p = random.choice(font_paths)
        img_color, text_color = get_text_and_img_color(white_img_p=0.75)
        temps = content.split('\n\n')
        for temp in temps:
            temp = temp.split('\n')
            for ele in temp:
                each_row_token_count = (width_max - 10-2*ft_s) // ft_s
                rows = 1 + (len(ele)+2)//each_row_token_count
                new_paragraph = True
                for row in range(rows):
                    start = row * each_row_token_count
                    end = (row + 1) * each_row_token_count
                    text = ele[start:end]
                    label += text + "\n"
                    if row_index == 0:
                        position = (10+2*ft_s, 10)
                    else:
                        if new_paragraph:
                            position = (10+2*ft_s, poisitons[-1][1] + ft_s * ratio_multi)
                        else:
                            position = (10, poisitons[-1][1] + int(ft_s * ratio_sinle))

                    row_index += 1

                    ft_se.append(ft_s)
                    ft_ps.append(ft_p)
                    text_colors.append(text_color)
                    texts.append(text)
                    poisitons.append(position)
                    new_paragraph = False
                label = label.strip("\n") + "\n\n"
        label = label.strip('\n\n')
    else:
        temps = content.split('\n\n')
        for temp in temps:
            temp = temp.split('\n')
            for ele in temp:
                # One font and color per paragraph
                ft_p = random.choice(font_paths)
                img_color, text_color = get_text_and_img_color(white_img_p=0.75)
                each_row_token_count = (width_max - 10-2*ft_s) // ft_s
                rows = 1 + (len(ele)+2)//each_row_token_count
                new_paragraph = True
                for row in range(rows):
                    start = row * each_row_token_count
                    end = (row + 1) * each_row_token_count
                    text = ele[start:end]
                    label += text + "\n"
                    if row_index == 0:
                        position = (10+2*ft_s, 10)
                    else:
                        if new_paragraph:
                            position = (10+2*ft_s, poisitons[-1][1] + ft_s * ratio_multi)
                        else:
                            position = (10, poisitons[-1][1] + int(ft_s * ratio_sinle))
                    row_index += 1
                    ft_se.append(ft_s)
                    ft_ps.append(ft_p)
                    text_colors.append(text_color)
                    texts.append(text)
                    poisitons.append(position)
                    new_paragraph = False
                label = label.strip("\n") + "\n\n"
        label = label.strip('\n\n')
    assert poisitons[-1][1] < height_max, "beyond img height"
    height = poisitons[-1][1] + ft_se[-1] + random.choice(range(20, 150))
    return ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width_max, height




def multi_3_gaokao_comprehension():
    width_max = 1300
    height_max = 1300
    font_size_list = list(range(5, 35))
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")


    paths = glob.glob('/data02/yanghuang/datasets/LLM_datasets/VGaokao-阅读理解/data/*/*.json')
    contents = []
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            datas = json.load(fp=f)['data']
            for data in datas:
                if "context" in data:
                    content = data['context']
                    if 100 <= len(content) <= 2048:
                        contents.append(content)

    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/gaokao_comprehension"
    if not os.path.exists(path):
        os.makedirs(path)

    total_count = 0
    with open('./solid_bg_imgs/multi_rows/gaokao_comprehension.json', 'w', encoding='utf-8') as f:
        for dialog in tqdm.tqdm(contents[0:], desc="multi_3_gaokao_comprehension",ncols=90):
            try:
                ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_multi_paragraph_img_infos(
                    dialog, font_size_list, font_paths, width_max, height_max)
                save_path = os.path.join(path, f"{total_count}.jpg")
                do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                          height=height, save_path=save_path)
                total_count += 1
                temp = {
                    "context": label,
                    "path": save_path,
                    "textsource": "gaokao_comprehension"
                }
                f.write(json.dumps(temp, ensure_ascii=False) + '\n')
            except Exception as e:
                print(e)
                continue
        print(f"\ntotal_count: {total_count}")


def multi_4_glm_common():
    contents = []
    general_paths = glob.glob("/data02/yanghuang/workspace/llm_platform/trainer/dataset/general/glm2*.jsonl")
    for path in general_paths:
        with open(path,'r',encoding='utf-8') as f:
            lines = f.readlines()
        desc = path.split('/')[-1].replace(".jsonl","")
        for line in tqdm.tqdm(lines, desc=f"{desc}"):
            line = json.loads(line)
            content = line['context'] + '\n'+ line['target']
            if 100 <= len(content) <= 2048:
                contents.append(content)


    contents = random.sample(contents,k=650000)

    path = f"/data02/yanghuang/workspace/MiniCPM-V/datas/solid_bg_imgs/multi_rows/glm_common"
    if not os.path.exists(path):
        os.makedirs(path)

    width_max = 1300
    height_max = 1300
    font_size_list = list(range(5, 35))
    font_paths = glob.glob("/data02/yanghuang/workspace/Qwen-VL/datas/fonts/*")

    total_count = 0
    with open('./solid_bg_imgs/multi_rows/glm_common.json', 'w', encoding='utf-8') as f:
        for dialog in tqdm.tqdm(contents[0:], desc="glm_common", ncols=90):
            try:
                ft_ps, ft_se, poisitons, texts, img_color, text_colors, label, width, height = compute_multi_paragraph_img_infos(
                    dialog, font_size_list, font_paths, width_max, height_max)
                save_path = os.path.join(path, f"{total_count}.jpg")
                do_create(ft_ps, ft_se, poisitons, texts, img_color=img_color, text_colors=text_colors, width=width,
                          height=height, save_path=save_path)
                total_count += 1
                temp = {
                    "context": label,
                    "path": save_path,
                    "textsource": "glm_common"
                }
                f.write(json.dumps(temp, ensure_ascii=False) + '\n')
            except Exception as e:
                print(e)
                continue
        print(f"\ntotal_count: {total_count}")



def create_multi_rows_words_img():
    multi_0_thousand_word()
    multi_1_poet()
    multi_2_dialog()
    multi_3_gaokao_comprehension()
    multi_4_glm_common()


def main():
    create_single_row_words_img()
    create_multi_rows_words_img()


if __name__ == '__main__':
    main()

The idea is to design a sensible layout for each kind of text and render single-line or multi-line images (the text files and font paths in the code above must be replaced with your own local paths before it will run). Below is a minimal direct call to do_create, followed by examples of the generated images:
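
As a quick usage sketch, do_create from the code above can be called directly to render one image (the font path is a placeholder to replace with a local .ttf):

do_create(
    font_paths=["/path/to/your/font.ttf"],   # replace with a local font file
    font_sizes=[20],
    positions=[(10, 40)],
    texts=["纯色背景OCR测试"],
    img_color=(255, 255, 255),                # white background
    text_colors=[(0, 0, 0)],                  # black text
    width=300,
    height=100,
    save_path="demo.jpg",
)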

Single-line text image

Multi-line text image 1

Multi-line text image 2

Run the un-fine-tuned QwenVL and MiniCPM-Llama3-V-2_5 on the generated images and compute sample-level accuracy, character edit distance, error rate (wer), and character-correct rate (ccor):

import Levenshtein

def wer_ccor_compute(ref, pre):
    substitution = 0
    deletion = 0
    insertion = 0
    # editops lists the operations that turn ref into pre
    results = Levenshtein.editops(ref, pre)
    for r in results:
        if "replace" in r:
            substitution += 1
        elif "delete" in r:
            deletion += 1
        else:
            insertion += 1

    wer = (substitution + deletion + insertion) / len(ref)
    ccor = (len(ref) - deletion - substitution) / len(ref)
    ld = substitution + deletion + insertion
    return substitution, deletion, insertion, ld, wer, ccor

The code above uses the Levenshtein library to compare the ref and pre strings, counting substitutions, deletions, and insertions and deriving wer, ccor, and ld (the edit distance; smaller is better). Note the original wer line double-counted insertions and omitted substitutions; it is corrected above to (substitution + deletion + insertion) / len(ref).
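
A quick sanity check of the metric on a toy pair:

ref = "识别图片中的全部文本"
pre = "识别图片中的全文本"                  # one character ("部") missing
s, d, i, ld, wer, ccor = wer_ccor_compute(ref, pre)
print(s, d, i, ld)                         # 0 1 0 1 -> a single deletion
print(f"wer={wer:.2f} ccor={ccor:.2f}")    # wer=0.10 ccor=0.90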

Results without fine-tuning

Results after fine-tuning

Before fine-tuning, neither model reads these solid-background images well, so their out-of-the-box OCR leaves room for improvement; MiniCPM-Llama3-V-2_5 beats QwenVL both before and after fine-tuning. After fine-tuning, MiniCPM-Llama3-V-2_5's sample accuracy climbs from 0.419 to 0.81 and QwenVL's from 0.15 to 0.75. On multi-line images, MiniCPM-Llama3-V-2_5 jumps from 0.0125 to 0.63, a huge gain, while QwenVL also improves, from 0.0095 to 0.3845, though it stays far behind.

Demo

The front end is built with Vue 3; the back end uses aiohttp and serves streaming inference. Back-end code:

import os
import pytomlpp as toml
config = toml.load('config.toml')
os.environ['CUDA_VISIBLE_DEVICES'] = config['common']['device']
import asyncio
from aiohttp import web
import time
import socket
import logging
import logging.handlers
import pandas as pd
import json
from aiohttp_cors import setup, ResourceOptions


class LLMVLInfer(object):
    def __init__(self, config):
        self.model_type = config['model']['model_type']
        self.model_path = config['model']['model_path']
        self.adapter_path = config['model']['adapter_path']
        self.model = globals()[self.model_type](self.model_path,  self.adapter_path)
        self.logger = self.create_logger()
        self.df = pd.read_csv(config['data']['test_file'])

    def inference(self, prompt, img_base64_list):
        response = self.model.chat(prompt, img_base64_list=img_base64_list)
        return response

    async def stream_generator(self,prompt, img_base64_list):
        stream_gen = self.model.chat(prompt, img_base64_list=img_base64_list, stream=True)
        for ele in stream_gen:
            yield ele

    def component_prompt(self, content, img):
        img_base64_list = None
        if self.model_type == "MiniCPMV2_5":
            msgs = [{"role": "user", "content": content}]
            prompt =  {"image": img, "question": json.dumps(msgs)}
        else:
            img_path = ""
            prompt = self.model.tokenizer.from_list_format([
                {'image': img_path},  # Either a local path or an url
                {'text': f'{content}'},
            ])
            img_base64_list = [img]
        return prompt, img_base64_list


    async def post(self, request:web.Request):
        req = await request.json()
        id = req['id']
        img = req['params']['data']['image']
        content = req['params']['data']['prompt']
        prompt, img_base64_list = self.component_prompt(content, img)
        start = time.time()
        try:
            # result = await asyncio.get_event_loop().run_in_executor(None, self.inference, prompt)
            result = await asyncio.to_thread(self.inference, prompt, img_base64_list)
            end = time.time()
        except Exception as e:
            self.logger.info(f'id: {id} inference fail: {e}')
            return web.json_response(self.build_fail_resp(id_=id, code=-1, msg=f"{e}"))
        tokens = len(self.model.tokenizer.encode(result))
        cost_time = (end-start)*1000
        speed = tokens/(end-start)
        send_data = self.build_resp_success(id, result, tokens, cost_time, speed)
        self.logger.info(json.dumps(send_data,ensure_ascii=False))
        return web.json_response(send_data)


    async def get_current_page_datas(self,request:web.Request):
        req = await request.json()
        self.logger.info(f"{req}")
        current_page = req['params']['currentPage']
        page_size = req['params']['pageSize']
        start = (current_page -1) * page_size
        end = current_page*page_size
        df = self.df[start:end]
        items = []
        for _, row in df.iterrows():
            temp = {'imgurl': None, 'imgpath': row['path'], 'percentage': 0, 'result': [row['输入文本'], row['识别文本']],
                    'levenshtein_distance': row['levenshtein_distance'], 'isright': row['是否正确'],"imgshow":False}
            items.append(temp)
        result = {
            "itemtotal":len(self.df),
            "items":items
        }
        self.logger.info(f"send current page [page-{current_page}] datas {result}")
        return web.json_response({
            "status":0,
            "result":result
        })


    async def get_img(self, request:web.Request):
        req = await request.json()
        self.logger.info(f"{req}")
        imgpath = req['params']['imgpath']
        with open(imgpath, 'rb') as f:
            image_data = f.read()
        headers = {'Content-Type': 'image/jpeg'}  # change the MIME type to match the image format
        return web.Response(body=image_data, headers=headers)

    async def getmertics(self, request:web.Request):
        df = self.df
        acc = len(df[df['是否正确'] == True]) / len(df)
        total_ld_count = 0
        total_word_count = 0
        deletions = 0
        substitutions = 0
        for _, row in df.iterrows():
            total_word_count += len(row['输入文本'])
            total_ld_count += int(row['levenshtein_distance'])
            deletions +=int(row['D'])
            substitutions += int(row['S'])
        wer = total_ld_count / total_word_count
        war = (total_word_count-deletions-substitutions)/total_word_count
        result = {
            "acc":acc,
            "wer":wer,
            "war":war
        }
        self.logger.info(f"{result}")
        return web.json_response(result)

    async def stream_infer(self, request:web.Request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        # Receive messages from the client
        async for msg in ws:
            if msg.type == web.WSMsgType.TEXT:
                data = msg.json()
                self.logger.info(f"ws request: {data}")
                # Parse the text and the base64-encoded image
                content = data.get('text', '')
                img = data.get('image_base64_str', '')
                prompt, img_base64_list = self.component_prompt(content, img)
                try:
                    stream_generator = self.stream_generator(prompt, img_base64_list)
                    async for token in stream_generator:
                        if token != "":
                            await ws.send_str(token)
                finally:
                    self.logger.info(f"ws close")
                    await ws.close()
        return ws

    async def handle(self, request:web.Request):
        return web.FileResponse('./dist/index.html')

    def build_fail_resp(self, id_: int, code: int, msg: str):
        # Return a plain dict; the caller wraps it in web.json_response
        return {
            'id': id_,
            'jsonrpc': '2.0',
            'ret': code,
            'result': {
                "error_info": msg
            }
        }


    def build_resp_success(self, id, answer, tokens, cost_time, speed):
        rsp = {
            "id": id,
            "jsonrpc": "2.0",
            "ret": 0,
            "result": {
                "chatInfo": {
                    "answer": answer,
                    "elements": []
                },
                "tokens": tokens,
                "cost_time": f"{cost_time} ms",
                "speed": f"{speed} tokens/s"
            }
        }
        return rsp

    def create_logger(self):
        log_level = config["log"]["log_level"]
        log_path = "./logs/server.log"
        logger = logging.getLogger(__name__)
        logger.setLevel(level=log_level)

        formatter = logging.Formatter("%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s")

        # A file handler; the commented variant rotates by size
        # file_handler = logging.handlers.RotatingFileHandler(filename=log_path, maxBytes=838860800, backupCount=20, encoding='utf-8')
        # This one rotates daily
        file_handler = logging.handlers.TimedRotatingFileHandler(filename=log_path, when='D', interval=1,
                                                                 encoding='utf-8')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level=log_level)
        logger.addHandler(file_handler)

        # A handler that echoes logs to the console
        console = logging.StreamHandler()
        console.setLevel(level=log_level)
        console.setFormatter(formatter)
        logger.addHandler(console)

        return logger



async def init_app():
    llmvl_infer = LLMVLInfer(config)
    app = web.Application()

    app.add_routes([
        web.post('/nlp',llmvl_infer.post),
        web.post('/CurrentPageDatas', llmvl_infer.get_current_page_datas),
        web.post('/fetchimg', llmvl_infer.get_img),
        web.get('/getmertics', llmvl_infer.getmertics),
        web.get("/ws",llmvl_infer.stream_infer),
        web.get("/", llmvl_infer.handle),
        web.static('/', path='./dist/', name='static')
    ])
    cors = setup(app, defaults={
        "*": ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    for route in list(app.router.routes()):
        cors.add(route)

    return app

if __name__ == '__main__':
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    bind_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)
    bind_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    bind_socket.bind(('0.0.0.0', 2222))
    web.run_app(init_app(), sock=bind_socket)
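
For a quick test, here is a minimal client call against the /nlp endpoint above (a sketch: the field names mirror what post() reads, the port comes from the bind address, and the requests library is assumed to be available):

import base64
import requests

with open("500_500_test.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "id": 1,
    "params": {"data": {"image": img_b64, "prompt": "请识别图片中的全部文本"}},
}
resp = requests.post("http://127.0.0.1:2222/nlp", json=payload)
print(resp.json()["result"]["chatInfo"]["answer"])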

Screenshots of the page:

4. Reflections

MiniCPM-Llama3-V-2_5 has somewhat fewer parameters than QwenVL, yet in actual use it performs much better. I believe pretraining data matters most: presumably MiniCPM-Llama3-V-2_5 was trained on more, higher-quality image-text data. Another factor is its front-end image processing and compression, which preserve more of the image at the cost of more image tokens; that also contributes to its results. The flip side is exactly that cost: because tokenizing one image consumes so many tokens, handling several images, or multi-turn multi-image dialogue, is inherently less flexible and more expensive than in the QwenVL family.

QwenVL has a flaw of its own: it splices the image path into the text prompt as a single input and decodes the path back out later, a design that is not decoupled and hurts serving performance. A client sends an image, so QwenVL must first write it to disk to obtain an img_path before it can run inference, and during inference the image gets read from disk yet again; that is clearly wasteful. One fix is to virtualize the image path into an id: keep an id-to-image mapping, splice the virtual id into the prompt, and have the backend resolve the id through the mapping instead of saving the image and re-reading the file.
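
A minimal sketch of that idea (names such as VirtualImageStore are mine, not Qwen-VL's API): register the incoming image under a virtual id, splice the id into the prompt instead of a real path, and let the image loader resolve ids from memory.

import io
import uuid
from PIL import Image

class VirtualImageStore:
    """Map virtual image ids to in-memory images so nothing is written to disk."""
    def __init__(self):
        self._images = {}

    def put(self, img_bytes: bytes) -> str:
        vid = f"virtual://{uuid.uuid4().hex}"
        self._images[vid] = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return vid

    def get(self, vid: str) -> Image.Image:
        return self._images.pop(vid)   # one-shot: free the memory after inference

store = VirtualImageStore()
# vid = store.put(request_image_bytes)
# query = tokenizer.from_list_format([{'image': vid}, {'text': prompt}])
# ...and inside the visual module, replace the disk read with: img = store.get(vid)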

References

minicpm-llama3-v-25

Qwen-VL

Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond

多模态大模型:视觉模型与LLM的结合之路(四)

多模态大模型QwenVL
