目录
三.Huggingface调用GOT Weights实现OCR
一.GOT-OCR-2.0模型介绍
GOT-OCR-2.0是以LVLM大视觉语言模型驱动的OCR识别模型,是一个端到端的通用模型:
1.补充了传统OCR系统需要通过整合多个领域专家网络来完成OCR任务。
2.扩展了基于LVLM的OCR模型的英文场景限制,可以更准确地识别中文场景。
3.模型结构:基于VitDet的预训练编码器---->125M OPT---->预训练的Qwen-0.5B---->Qwen-0.5B解码器
4.训练阶段:
a)VitDet编码器预训练阶段: 使用了大约5M对图像-文本,包括3M个场景文本(英文中文场景各一半)OCR数据和2M个文档OCR数据。
b)联合Qwen-0.5B预训练阶段:使用多种格式数据:普通OCR、Mathpix-markdown格式化数据(数学公式、分子公式、表、Mathpix格式)、更通用的数据(乐谱、几何图形、图表)
c)解码器训练阶段:为了实现细粒度、多页和动态分辨率OCR,使用了不同的数据集。
二.官方源码地址
https://github.com/Ucas-HaoranWei/GOT-OCR2.0
三.Huggingface调用GOT Weights实现OCR
https://huggingface.co/stepfun-ai/GOT-OCR2_0(Huggingface存储GOT Weights地址)
1.下载环境
pip install transformers
pip install tiktoken
pip install verovio
pip install accelerate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
2.在Huggingface下载文件
3.调用
把 'ucaslcl/GOT-OCR2_0'替换成自己的文件夹路径
from transformers import AutoModel, AutoTokenizer
# 'ucaslcl/GOT-OCR2_0'是存放地址的文件夹
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
model = model.eval().cuda()
# input your test image
image_file = 'xxx.jpg'
# plain texts OCR
res = model.chat(tokenizer, image_file, ocr_type='ocr')
# format texts OCR:
# res = model.chat(tokenizer, image_file, ocr_type='format')
# fine-grained OCR:
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_color='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_color='')
# multi-crop OCR:
# res = model.chat_crop(tokenizer, image_file, ocr_type='ocr')
# res = model.chat_crop(tokenizer, image_file, ocr_type='format')
# render the formatted OCR results:
# res = model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file = './demo.html')
print(res)
4.结果
得到识别结果是字符串格式