from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
import requests
# Load the pretrained MixTex LaTeX-OCR components from the Hugging Face Hub
# (network I/O on first run; cached afterwards).
feature_extractor = AutoImageProcessor.from_pretrained("MixTex/ZhEn-Latex-OCR")
# NOTE(review): `max_len` is not a standard AutoTokenizer kwarg — the modern
# name is `model_max_length`; confirm this actually takes effect.
tokenizer = AutoTokenizer.from_pretrained("MixTex/ZhEn-Latex-OCR", max_len=296)
model = VisionEncoderDecoderModel.from_pretrained("MixTex/ZhEn-Latex-OCR")
# Quick inference smoke test (kept for reference): OCR a sample image and
# rewrite display-math delimiters to align* environments.
#imgen = Image.open(requests.get('https://cdn-uploads.huggingface.co/production/uploads/62dbaade36292040577d2d4f/eOAym7FZDsjic_8ptsC-H.png', stream=True).raw)
#print(tokenizer.decode(model.generate(feature_extractor(imgen, return_tensors="pt").pixel_values)[0]).replace('\\[','\\begin{align*}').replace('\\]','\\end{align*}'))
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset
from PIL import Image
import torch
# Pseudo-labelled (image, LaTeX text) pairs used for fine-tuning below.
dataframe = load_dataset("MixTex/Pseudo-Latex-ZhEn-1")
from torch.utils.data import Dataset
class MixTexDataset(Dataset):
    """Torch dataset pairing formula images with LaTeX token labels.

    Wraps a Hugging Face ``DatasetDict`` (only its ``'train'`` split is used)
    and yields model-ready tensors for a VisionEncoderDecoder model:
    ``pixel_values`` from the image processor and ``labels`` with padding
    positions masked to ``-100`` so the loss ignores them.
    """

    def __init__(self, dataframe, tokenizer, feature_extractor, max_length=256):
        # dataframe: DatasetDict whose 'train' split has {'image', 'text'} rows
        # tokenizer / feature_extractor: the MixTex components loaded above
        # max_length: fixed label length (pad/truncate target LaTeX to this)
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataframe['train'])

    def __getitem__(self, idx):
        # Fetch the row once instead of indexing the split twice.
        example = self.dataframe['train'][idx]
        image = example['image'].convert("RGB")
        target_text = example['text']
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
        target = self.tokenizer(
            target_text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        ).input_ids
        # Replace pad tokens with -100 so CrossEntropyLoss skips them.
        labels = [tok if tok != self.tokenizer.pad_token_id else -100 for tok in target]
        # squeeze(0) removes only the batch dim added by return_tensors="pt";
        # a bare squeeze() could also collapse a genuine size-1 dimension.
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}
# Assemble the training dataset, configure the run, and launch fine-tuning.
traindataset = MixTexDataset(dataframe, tokenizer, feature_extractor=feature_extractor)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",           # checkpoints written here
    logging_dir='./logs',
    num_train_epochs=3,
    per_device_train_batch_size=12,
    learning_rate=5e-5,
    fp16=True,                        # mixed-precision training
    predict_with_generate=True,
    logging_steps=100,
    save_steps=500,
    save_total_limit=1,               # keep only the most recent checkpoint
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=traindataset,
)
trainer.train()