山东大学创新实训第九周周报------VisualGLM-6B后端接口实现

VisualGLM-6B 的开源仓库中提供了一个网页版 demo 及其对应的代码,仓库地址:THUDM/VisualGLM-6B: Chinese and English multimodal conversational language model | 多模态中英双语对话语言模型 (https://github.com/THUDM/VisualGLM-6B)。网页版 demo 的代码如下:

#!/usr/bin/env python

import gradio as gr
from PIL import Image
import os
import json
from model import is_chinese, get_infer_setting, generate_input, chat
import torch

def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Run one round of image-grounded chat and return the model's answer.

    Relies on the module-level ``model`` and ``tokenizer`` globals, which
    ``main`` assigns before the UI starts serving requests.

    Args:
        input_text: The user's prompt for this round.
        image: The image handed to ``generate_input`` (passed with
            ``image_is_encoded=False``).
        history: Previous conversation turns; ``None`` means no history.
        request_data: Optional overrides merged over the default generation
            parameters (e.g. ``temperature``, ``top_p``).
        is_zh: Whether the prompt is Chinese; ``chat`` is called with
            ``english=not is_zh``.

    Returns:
        The answer produced by ``chat``.
    """
    # None sentinels replace the original mutable defaults ([] / dict()),
    # which are shared between calls in Python.
    if history is None:
        history = []

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    if request_data:
        input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # Inference only: no gradients are needed. Only max_length, top_p, top_k
    # and temperature from gen_kwargs are forwarded to chat().
    with torch.no_grad():
        answer, history, _ = chat(
            None, model, tokenizer, input_text,
            history=history, image=input_image,
            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )
    return answer


def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Gradio callback: answer ``input_text`` about the uploaded image.

    Args:
        input_text: Text entered by the user.
        temperature: Sampling temperature taken from the slider.
        top_p: Top-p value taken from the slider.
        image_prompt: File path of the uploaded image, or ``None`` when no
            image has been uploaded.
        result_previous: Current chatbot contents as a sequence of
            ``(user, bot)`` pairs; ``None`` is treated as empty.

    Returns:
        A ``(textbox_value, chat_history)`` tuple. The textbox keeps the
        user's text only when the image is missing; otherwise it is cleared.
    """
    # Drop turns where either side is empty (such as the initial greeting,
    # whose user side is ""), so they are not fed to the model as history.
    result_text = [
        (ele[0], ele[1])
        for ele in (result_previous or [])
        if ele[0] != "" and ele[1] != ""
    ]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
        return input_text, result_text
    elif input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    try:
        # Opening the image inside the try block means an unreadable or
        # corrupt file is reported in the chat instead of raising out of the
        # callback; the context manager closes the file handle afterwards.
        with Image.open(image_prompt) as image:
            answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        # Every failure is surfaced to the user with the same retry message.
        print(f"error: {e}")
        if is_zh:
            result_text.append((input_text, '超时!请稍等几分钟再重试。'))
        else:
            result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text


# Markdown heading (with a link to the upstream repository) rendered at the top of the page.
DESCRIPTION = '''# <a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a>'''

# English usage hints shown in the left column of the UI.
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
# Chinese version of the same hints; not referenced by the code visible in this listing.
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'

# Attribution note rendered below the main layout.
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'


def clear_fn(value):
    """Return reset values for the text box, the chat history and the image.

    ``value`` is ignored; it exists only because the event wiring passes an
    input component.
    """
    greeting_history = [("", "Hi, What do you want to know about this image?")]
    empty_text, no_image = "", None
    return empty_text, greeting_history, no_image

def clear_fn2(value):
    """Return a fresh chat history holding only the greeting message.

    ``value`` is ignored; it exists only because the event wiring passes an
    input component.
    """
    greeting_turn = ("", "Hi, What do you want to know about this image?")
    return [greeting_turn]


def main(args):
    """Load the model, build the Gradio interface and launch the demo.

    Args:
        args: Parsed command-line options; ``args.quant`` is forwarded to
            ``get_infer_setting`` and ``args.share`` to ``demo.launch``.
    """
    gr.close_all()
    # model/tokenizer are stored as module globals because
    # generate_text_with_image reads them without receiving them as arguments.
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)
    
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # Left column: text prompt, buttons, image upload, sampling sliders and hints.
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    # type="filepath": request_model receives a path and opens it with PIL.
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            # Right column: the conversation, pre-filled with a greeting turn.
            # NOTE(review): .style(...) here and concurrency_count below look like
            # Gradio 3.x APIs - confirm against the installed Gradio version.
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)
        # Clicking "Generate" and pressing ENTER in the text box run the same handler.
        run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        # clear_fn / clear_fn2 ignore their argument; clear_button is passed only as a placeholder input.
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or removing the image resets the conversation to the greeting.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

        print(gr.__version__)

    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)


if __name__ == '__main__':
    # Script entry point: parse the command-line options and start the demo.
    import argparse

    cli = argparse.ArgumentParser()
    # --quant: optional quantization level (4 or 8); omitted means None.
    cli.add_argument("--quant", choices=[8, 4], type=int, default=None)
    # --share: forwarded to demo.launch(share=...).
    cli.add_argument("--share", action="store_true")

    main(cli.parse_args())

Gradio 是一个用于构建交互式应用程序的 Python 框架,它可以简单、快速地为机器学习模型创建用户友好的交互界面。上述代码使用 Gradio 构建了一个交互式应用程序,根据用户上传的图片和输入的文本生成回答。

对于小程序后端,我们使用Python的Flask框架实现相似的功能,代码如下:

import os
from flask import Flask, request, jsonify
from PIL import Image
import torch
from model import is_chinese, get_infer_setting, generate_input, chat

# Flask application object; routes below are registered on it.
app = Flask(__name__)
# Directory (relative to the current working directory) where uploaded images are stored.
UPLOAD_FOLDER = 'uploads'
# Create the upload directory at import time; exist_ok avoids an error if it already exists.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)


def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Run one round of image-grounded chat and return the model's answer.

    Relies on the module-level ``model`` and ``tokenizer`` names, which the
    ``__main__`` section assigns before the server starts.

    Args:
        input_text: The user's prompt for this round.
        image: The image handed to ``generate_input`` (passed with
            ``image_is_encoded=False``).
        history: Previous conversation turns; ``None`` means no history.
        request_data: Optional overrides merged over the default generation
            parameters (e.g. ``temperature``, ``top_p``).
        is_zh: Whether the prompt is Chinese; ``chat`` is called with
            ``english=not is_zh``.

    Returns:
        The answer produced by ``chat``.
    """
    # None sentinels replace the original mutable defaults ([] / dict()),
    # which are shared between calls in Python.
    if history is None:
        history = []

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    if request_data:
        input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # Inference only: no gradients are needed. Only max_length, top_p, top_k
    # and temperature from gen_kwargs are forwarded to chat().
    with torch.no_grad():
        answer, history, _ = chat(
            None, model, tokenizer, input_text,
            history=history, image=input_image,
            max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )
    return answer


@app.route('/process', methods=['POST'])
def process():
    """Handle ``POST /process``: answer a text prompt about an uploaded image.

    Form fields:
        text: The prompt (required).
        image: The uploaded image file (required).
        temperature: Optional float, default 0.8.
        top_p: Optional float, default 0.4.

    Returns:
        ``{'result': answer}`` on success; ``{'error': ...}`` with status 400
        for missing or malformed input, or status 500 when opening the image
        or running the model fails.
    """
    import uuid  # local import: only needed to name stored uploads

    input_text = request.form.get('text')
    image_file = request.files.get('image')

    if not input_text or not image_file:
        return jsonify({'error': 'Missing text or image'}), 400

    # Validate the numeric parameters before touching the disk, so malformed
    # values yield a 400 instead of an unhandled ValueError.
    try:
        request_data = {
            "temperature": float(request.form.get('temperature', 0.8)),
            "top_p": float(request.form.get('top_p', 0.4))
        }
    except ValueError:
        return jsonify({'error': 'temperature and top_p must be numbers'}), 400

    # Save the uploaded image. The client-supplied filename is untrusted
    # (it may contain path separators or ".."), so only a sanitized
    # extension is kept and the name itself is generated. A unique name also
    # stops concurrent uploads from overwriting each other.
    ext = os.path.splitext(image_file.filename or '')[1]
    ext = ''.join(ch for ch in ext if ch.isalnum())[:10]
    stored_name = uuid.uuid4().hex + (f'.{ext}' if ext else '')
    image_path = os.path.join(UPLOAD_FOLDER, stored_name)
    image_file.save(image_path)

    is_zh = is_chinese(input_text)
    history = []

    try:
        # Opening the image inside the try block turns an unreadable upload
        # into the JSON error response; the context manager closes the file.
        with Image.open(image_path) as image:
            answer = generate_text_with_image(input_text, image, history, request_data, is_zh)
    except Exception as e:
        print(f"error: {e}")
        return jsonify({'error': 'Error processing request'}), 500

    return jsonify({'result': answer})


if __name__ == '__main__':
    # Script entry point: parse options, load the model, start the Flask server.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
    # Still accepted so existing command lines keep working; this script does not read it.
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--debug", action="store_true",
                        help="enable Flask debug mode (binds to 127.0.0.1 only)")
    args = parser.parse_args()

    # Module-level assignment: generate_text_with_image reads these names as globals.
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

    # Debug mode enables the interactive Werkzeug debugger, which allows
    # arbitrary code execution, so it is opt-in and never exposed on all
    # interfaces. The reloader is disabled because it re-runs this script in
    # a child process, which would load the model a second time.
    host = '127.0.0.1' if args.debug else '0.0.0.0'
    app.run(debug=args.debug, use_reloader=False, host=host, port=5000)

这段代码使用 Flask 框架构建了一个简单的 Web 服务,接收客户端上传的文本和图片,通过调用预训练的深度学习模型生成文本描述,并返回结果。主要包括设置上传目录、定义生成文本描述的函数、处理请求的路由以及启动应用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值