MixTex: a handy LaTeX converter

MixTex (published on the Hugging Face Hub as MixTex/ZhEn-Latex-OCR) is an OCR model that converts images of mixed Chinese/English text and formulas into LaTeX. The first snippet below loads the model and shows inference; the second fine-tunes it on the MixTex/Pseudo-Latex-ZhEn-1 dataset.

from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
import requests

# Load the MixTex image processor, tokenizer, and vision-encoder-decoder model from the Hub
feature_extractor = AutoImageProcessor.from_pretrained("MixTex/ZhEn-Latex-OCR")
tokenizer = AutoTokenizer.from_pretrained("MixTex/ZhEn-Latex-OCR", model_max_length=296)
model = VisionEncoderDecoderModel.from_pretrained("MixTex/ZhEn-Latex-OCR")
# Optional inference example: fetch a sample image, run OCR, and rewrite \[ ... \] as align* environments
#imgen = Image.open(requests.get('https://cdn-uploads.huggingface.co/production/uploads/62dbaade36292040577d2d4f/eOAym7FZDsjic_8ptsC-H.png', stream=True).raw)
#print(tokenizer.decode(model.generate(feature_extractor(imgen, return_tensors="pt").pixel_values)[0]).replace('\\[','\\begin{align*}').replace('\\]','\\end{align*}'))
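The commented-out example above decodes the raw output, which still contains special tokens such as <s> and </s>. Here is a standalone sketch of the same inference path; passing skip_special_tokens=True to tokenizer.decode is a standard transformers option, and using it here is my addition rather than something the model card prescribes.

# Standalone inference sketch; image URL taken from the example above
imgen = Image.open(requests.get('https://cdn-uploads.huggingface.co/production/uploads/62dbaade36292040577d2d4f/eOAym7FZDsjic_8ptsC-H.png', stream=True).raw)
pixel_values = feature_extractor(imgen, return_tensors="pt").pixel_values
output_ids = model.generate(pixel_values)[0]
# skip_special_tokens=True strips <s>/</s> markers from the decoded LaTeX
latex = tokenizer.decode(output_ids, skip_special_tokens=True)
print(latex.replace('\\[', '\\begin{align*}').replace('\\]', '\\end{align*}'))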
# Fine-tuning setup: load the MixTex/Pseudo-Latex-ZhEn-1 training data
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset
from torch.utils.data import Dataset
import torch

dataframe = load_dataset("MixTex/Pseudo-Latex-ZhEn-1")
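Before wrapping the data for training, a quick peek confirms the fields the wrapper below relies on (an 'image' and a 'text' column in the train split):

# Inspect the splits and verify the expected fields
print(dataframe)
print(dataframe['train'][0].keys())  # should include 'image' and 'text'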
class MixTexDataset(Dataset):
    """Wraps the HF dataset: images become pixel_values, LaTeX text becomes label ids."""
    def __init__(self, dataframe, tokenizer, feature_extractor, max_length=256):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataframe['train'])

    def __getitem__(self, idx):
        image = self.dataframe['train'][idx]['image'].convert("RGB")
        target_text = self.dataframe['train'][idx]['text']
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
        target = self.tokenizer(target_text, padding="max_length", max_length=self.max_length, truncation=True).input_ids
        # Replace pad token ids with -100 so they are ignored by the cross-entropy loss
        labels = [label if label != self.tokenizer.pad_token_id else -100 for label in target]
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

traindataset = MixTexDataset(dataframe, tokenizer, feature_extractor=feature_extractor)
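A one-sample sanity check is a cheap way to catch shape problems before committing to a full run; the exact image dimensions below depend on the processor's configured size:

# One sample should yield an image tensor and a (max_length,) label vector
sample = traindataset[0]
print(sample["pixel_values"].shape)  # (3, H, W); H and W set by the image processor
print(sample["labels"].shape)        # torch.Size([256]) with the default max_length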
# Standard seq2seq fine-tuning configuration; fp16=True assumes a CUDA-capable GPU
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=12,
    predict_with_generate=True,
    logging_dir='./logs',
    learning_rate=5e-5,
    save_total_limit=1,
    logging_steps=100,
    save_steps=500,
    num_train_epochs=3,
    fp16=True,
)

# The default data collator suffices here since __getitem__ already returns fixed-shape tensors
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=traindataset,
)
trainer.train()
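Seq2SeqTrainer only writes periodic checkpoints under output_dir; the script never saves a final model explicitly. A minimal sketch for persisting and reloading the fine-tuned weights follows, where ./mixtex-finetuned is an assumed path, not one from the original post:

# Save the fine-tuned model along with the tokenizer/processor needed to reuse it
trainer.save_model("./mixtex-finetuned")        # assumed output path
tokenizer.save_pretrained("./mixtex-finetuned")
feature_extractor.save_pretrained("./mixtex-finetuned")

# Reload for inference exactly like the original Hub checkpoint
finetuned = VisionEncoderDecoderModel.from_pretrained("./mixtex-finetuned")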
