(2) Latex_OCR project: train, train_resize, eval, cli
This installment continues a project that converts formula images into LaTeX, so that users can extract formulas from pictures and pass them on to our large language model.
The previous post implemented the Encoder, Decoder, get_encoder and get_decoder code; this one implements train, train_resize, eval and cli.
1. Code
train:
from pix2tex.dataset.dataset import Im2LatexDataset
import os
import argparse
import logging
import yaml
import torch
from munch import Munch
from tqdm.auto import tqdm
import wandb
import torch.nn as nn
from pix2tex.eval import evaluate
from pix2tex.models import get_model
# from pix2tex.utils import *
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
def train(args):
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(**args, test=False)
    valdataloader = Im2LatexDataset().load(args.valdata)
    valargs = args.copy()
    valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
    valdataloader.update(**valargs)
    device = args.device
    model = get_model(args)
    if torch.cuda.is_available() and not args.no_cuda:
        gpu_memory_check(model, args)
    max_bleu, max_token_acc = 0, 0
    out_path = os.path.join(args.model_path, args.name)
    os.makedirs(out_path, exist_ok=True)

    if args.load_chkpt is not None:
        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))

    def save_models(e, step=0):
        torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
        yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))

    opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
    scheduler = get_scheduler(args.scheduler)(opt, step_size=args.lr_step, gamma=args.gamma)

    # gradient accumulation: fall back to the full batch size if no micro-batch size is set
    microbatch = args.get('micro_batchsize', -1)
    if microbatch == -1:
        microbatch = args.batchsize

    try:
        for e in range(args.epoch, args.epochs):
            args.epoch = e
            dset = tqdm(iter(dataloader))
            for i, (seq, im) in enumerate(dset):
                if seq is not None and im is not None:
                    opt.zero_grad()
                    total_loss = 0
                    for j in range(0, len(im), microbatch):
                        tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
                        loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
                        loss.backward()  # with data parallelism the loss is a vector
                        total_loss += loss.item()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    opt.step()
                    scheduler.step()
                    dset.set_description('Loss: %.4f' % total_loss)
                    if args.wandb:
                        wandb.log({'train/loss': total_loss})
                # periodically evaluate and keep the checkpoint with the best BLEU / token accuracy
                if (i+1+len(dataloader)*e) % args.sample_freq == 0:
                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
                    if bleu_score > max_bleu and token_accuracy > max_token_acc:
                        max_bleu, max_token_acc = bleu_score, token_accuracy
                        save_models(e, step=i)
            if (e+1) % args.save_freq == 0:
                save_models(e, step=len(dataloader))
                if args.wandb:
                    wandb.log({'train/epoch': e+1})
    except KeyboardInterrupt:
        if e >= 2:
            save_models(e, step=i)
        raise KeyboardInterrupt
    save_models(e, step=len(dataloader))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
    parser.add_argument('--debug', action='store_true', help='DEBUG')
    parser.add_argument('--resume', help='path to checkpoint folder', action='store_true')
    parsed_args = parser.parse_args()
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/debug.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params), **vars(parsed_args))
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
    seed_everything(args.seed)
    if args.wandb:
        if not parsed_args.resume:
            args.id = wandb.util.generate_id()
        wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
        args = Munch(wandb.config)
    train(args)
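All hyperparameters for train come from the YAML file passed via --config (settings/debug.yaml when none is given). Below is a minimal config sketch covering the keys the script reads; the values are illustrative placeholders, not the project's tuned settings, and the run command assumes the file lives at pix2tex/train.py:

python -m pix2tex.train --config path/to/config.yaml

# illustrative config.yaml; keys taken from the train script above
data: dataset/data/train.pkl        # training Im2LatexDataset pickle
valdata: dataset/data/val.pkl       # validation set
model_path: checkpoints             # root directory for saved weights
name: pix2tex                       # run name used in checkpoint filenames
load_chkpt: null                    # optional checkpoint to warm-start from
epoch: 0                            # starting epoch (resume counter)
epochs: 10
batchsize: 10
testbatchsize: 20
micro_batchsize: -1                 # -1 = no gradient accumulation
optimizer: Adam                     # resolved by get_optimizer
lr: 0.001
betas: [0.9, 0.999]
scheduler: StepLR                   # resolved by get_scheduler (step_size/gamma below)
lr_step: 30
gamma: 0.5
sample_freq: 1000                   # validate every N global steps
save_freq: 1                        # save every N epochs
valbatches: 100                     # validation batches per evaluation
seed: 42
wandb: False
gpu_devices: [0]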
train_resize:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
import numpy as np
from PIL import Image
import cv2
import imagesize
import yaml
from tqdm.auto import tqdm
from pix2tex.utils import *
from pix2tex.dataset.dataset import *
from munch import Munch
import argparse
import os  # used in __main__ for os.path.realpath; may also arrive via the star imports
from typing import Tuple
def prepare_data(dataloader: Im2LatexDataset) -> Tuple[torch.Tensor, torch.Tensor]:
    _, ims = dataloader.pairs[dataloader.i-1].T
    images = []
    scale = None
    c = 0
    width, height = imagesize.get(ims[0])
    # sample a random scale factor that keeps the image within [16, max_dimensions]
    while True:
        c += 1
        s = np.array([width, height])
        scale = 5*(np.random.random()+.02)
        if all((s*scale) <= dataloader.max_dimensions[0]) and all((s*scale) >= 16):
            break
        if c > 25:
            return None, None
    x, y = 0, 0
    for path in list(ims):
        im = Image.open(path)
        modes = [Image.Resampling.BICUBIC,
                 Image.Resampling.BILINEAR]
        if scale < 1:
            modes.append(Image.Resampling.LANCZOS)
        m = modes[int(len(modes)*np.random.random())]
        im = im.resize((int(width*scale), int(height*scale)), m)
        try:
            im = pad(im)
        except Exception:
            return None, None
        if im is None:
            print(path, 'not found!')
            continue
        im = np.array(im)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        images.append(dataloader.transform(image=im)['image'][:1])
        if images[-1].shape[-1] > x:
            x = images[-1].shape[-1]
        if images[-1].shape[-2] > y:
            y = images[-1].shape[-2]
    if x > dataloader.max_dimensions[0] or y > dataloader.max_dimensions[1]:
        return None, None
    # zero-pad every image in the batch to the same (y, x) size
    for i in range(len(images)):
        h, w = images[i].shape[1:]
        images[i] = F.pad(images[i], (0, x-w, 0, y-h), value=0)
    try:
        images = torch.cat(images).float().unsqueeze(1)
    except RuntimeError:
        # print('Images not working: %s' % (' '.join(list(ims))))
        return None, None
    dataloader.i += 1
    # the class label encodes the original (unscaled) width in multiples of 32
    labels = torch.tensor(width//32-1).repeat(len(ims)).long()
    return images, labels


def val(val: Im2LatexDataset, model: ResNetV2, num_samples=400, device='cuda') -> float:
    model.eval()
    c, t = 0, 0
    iter(val)
    with torch.no_grad():
        for i in range(num_samples):
            im, l = prepare_data(val)
            if im is None:
                continue
            p = model(im.to(device)).argmax(-1).detach().cpu().numpy()
            c += (p == l[0].item()).sum()
            t += len(im)
    model.train()
    return c/t


def main(args):
    # data
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(batchsize=args.batchsize, test=False, max_dimensions=args.max_dimensions, keep_smaller_batches=True, device=args.device)
    valloader = Im2LatexDataset().load(args.valdata)
    valloader.update(batchsize=args.batchsize, test=True, max_dimensions=args.max_dimensions, keep_smaller_batches=True, device=args.device)
    # model: one output class per 32 px of target width
    model = ResNetV2(layers=[2, 3, 3], num_classes=int(max(args.max_dimensions)//32), global_pool='avg', in_chans=args.channels, drop_rate=.05,
                     preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
    opt = Adam(model.parameters(), lr=args.lr)
    crit = nn.CrossEntropyLoss()
    sched = OneCycleLR(opt, .005, total_steps=args.num_epochs*len(dataloader))
    global bestacc
    bestacc = val(valloader, model, args.valbatches, args.device)

    def train_epoch(sched=None):
        iter(dataloader)
        dset = tqdm(range(len(dataloader)))
        for i in dset:
            im, label = prepare_data(dataloader)
            if im is not None:
                if im.shape[-1] > dataloader.max_dimensions[0] or im.shape[-2] > dataloader.max_dimensions[1]:
                    continue
                opt.zero_grad()
                label = label.to(args.device)
                pred = model(im.to(args.device))
                loss = crit(pred, label)
                if i % 2 == 0:
                    dset.set_description('Loss: %.4f' % loss.item())
                loss.backward()
                opt.step()
                if sched is not None:
                    sched.step()
            # periodically validate and keep the best checkpoint
            if (i+1) % args.sample_freq == 0 or i+1 == len(dset):
                acc = val(valloader, model, args.valbatches, args.device)
                print('Accuracy %.2f' % (100*acc), '%')
                global bestacc
                if acc > bestacc:
                    torch.save(model.state_dict(), args.out)
                    bestacc = acc

    for _ in range(args.num_epochs):
        train_epoch(sched)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train size classification model')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--resume', help='path to checkpoint folder', type=str, default='')
    parser.add_argument('--out', type=str, default='checkpoints/image_resizer.pth', help='output destination for trained model')
    parser.add_argument('--num_epochs', type=int, default=10, help='number of epochs to train')
    parser.add_argument('--batchsize', type=int, default=10)
    parsed_args = parser.parse_args()
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/debug.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params), **vars(parsed_args))
    args.update(**vars(parsed_args))
    main(args)
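A typical run, assuming the script is saved as pix2tex/train_resizer.py (the module name here is my assumption; the keys data, valdata, max_dimensions, channels, sample_freq and valbatches come from the same YAML config as the main training script):

python -m pix2tex.train_resizer --config settings/debug.yaml --num_epochs 10 --out checkpoints/image_resizer.pth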
eval:
from pix2tex.dataset.dataset import Im2LatexDataset
import argparse
import os  # used in __main__ for os.path.realpath; may also arrive via the star import
import logging
import yaml
import numpy as np
import torch
from torchtext.data import metrics
from munch import Munch
from tqdm.auto import tqdm
import wandb
from Levenshtein import distance
from pix2tex.models import get_model, Model
from pix2tex.utils import *
def detokenize(tokens, tokenizer):
    toks = [tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
    for b in range(len(toks)):
        for i in reversed(range(len(toks[b]))):
            if toks[b][i] is None:
                toks[b][i] = ''
            toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
            if toks[b][i] in ('[BOS]', '[EOS]', '[PAD]'):
                del toks[b][i]
    return toks


@torch.no_grad()
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
    """Evaluates the model. Returns the BLEU score on the dataset.

    Args:
        model (torch.nn.Module): the model
        dataset (Im2LatexDataset): test dataset
        args (Munch): arguments
        num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
        name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.

    Returns:
        Tuple[float, float, float]: BLEU score, normed edit distance, token accuracy
    """
    assert len(dataset) > 0
    device = args.device
    log = {}
    bleus, edit_dists, token_acc = [], [], []
    bleu_score, edit_distance, token_accuracy = 0, 1, 0
    pbar = tqdm(enumerate(iter(dataset)), total=len(dataset))
    for i, (seq, im) in pbar:
        if seq is None or im is None:
            continue
        # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
        dec = model.generate(im.to(device), temperature=args.get('temperature', .2))
        pred = detokenize(dec, dataset.tokenizer)
        truth = detokenize(seq['input_ids'], dataset.tokenizer)
        bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
        for predi, truthi in zip(token2str(dec, dataset.tokenizer), token2str(seq['input_ids'], dataset.tokenizer)):
            ts = post_process(truthi)
            if len(ts) > 0:
                edit_dists.append(distance(post_process(predi), ts)/len(ts))
        dec = dec.cpu()
        tgt_seq = seq['input_ids'][:, 1:]
        # pad prediction and target to the same length before comparing tokens
        shape_diff = dec.shape[1]-tgt_seq.shape[1]
        if shape_diff < 0:
            dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token)
        elif shape_diff > 0:
            tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token)
        mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token)
        tok_acc = (dec == tgt_seq)[mask].float().mean().item()
        token_acc.append(tok_acc)
        pbar.set_description('BLEU: %.3f, ED: %.2e, ACC: %.3f' % (np.mean(bleus), np.mean(edit_dists), np.mean(token_acc)))
        if num_batches is not None and i >= num_batches:
            break
    if len(bleus) > 0:
        bleu_score = np.mean(bleus)
        log[name+'/bleu'] = bleu_score
    if len(edit_dists) > 0:
        edit_distance = np.mean(edit_dists)
        log[name+'/edit_distance'] = edit_distance
    if len(token_acc) > 0:
        token_accuracy = np.mean(token_acc)
        log[name+'/token_acc'] = token_accuracy
    if args.wandb:
        # log a few sample predictions
        pred = token2str(dec, dataset.tokenizer)
        truth = token2str(seq['input_ids'], dataset.tokenizer)
        table = wandb.Table(columns=["Truth", "Prediction"])
        for k in range(min([len(pred), args.test_samples])):
            table.add_data(post_process(truth[k]), post_process(pred[k]))
        log[name+'/examples'] = table
        wandb.log(log)
    else:
        print('\n%s\n%s' % (truth, pred))
        print('BLEU: %.2f' % bleu_score)
    return bleu_score, edit_distance, token_accuracy


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test model')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('-c', '--checkpoint', default=None, type=str, help='path to model checkpoint')
    parser.add_argument('-d', '--data', default='dataset/data/val.pkl', type=str, help='Path to Dataset pkl file')
    parser.add_argument('--no-cuda', action='store_true', help='Use CPU')
    parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size')
    parser.add_argument('--debug', action='store_true', help='DEBUG')
    parser.add_argument('-t', '--temperature', type=float, default=.333, help='sampling temperature')
    parser.add_argument('-n', '--num-batches', type=int, default=None, help='how many batches to evaluate on. Defaults to None (all)')
    parsed_args = parser.parse_args()
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/config.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params))
    args.testbatchsize = parsed_args.batchsize
    args.wandb = False
    args.temperature = parsed_args.temperature
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
    seed_everything(args.seed if 'seed' in args else 42)
    model = get_model(args)
    if parsed_args.checkpoint is None:
        with in_model_path():
            parsed_args.checkpoint = os.path.realpath('checkpoints/weights.pth')
    model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))
    dataset = Im2LatexDataset().load(parsed_args.data)
    valargs = args.copy()
    valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
    dataset.update(**valargs)
    evaluate(model, dataset, args, num_batches=parsed_args.num_batches)
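Since train.py imports this module as pix2tex.eval, it can also be run stand-alone. For example, to score a checkpoint on the validation set using the flags defined above (paths are illustrative):

python -m pix2tex.eval --config settings/config.yaml -c checkpoints/weights.pth -d dataset/data/val.pkl -b 10 -t 0.333 -n 50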
cli:
from pix2tex.dataset.transforms import test_transform
import pandas.io.clipboard as clipboard
from PIL import ImageGrab
from PIL import Image
import os
from pathlib import Path
import sys
from typing import List, Optional, Tuple
import atexit
from contextlib import suppress
import logging
import yaml
import re
with suppress(ImportError, AttributeError):
    import readline
import numpy as np
import torch
from torch._appdirs import user_data_dir
from munch import Munch
from transformers import PreTrainedTokenizerFast
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from pix2tex.dataset.latex2png import tex2pil
from pix2tex.models import get_model
from pix2tex.utils import *
from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints  # needed by LatexOCR.__init__ below
def minmax_size(img: Image, max_dimensions: Tuple[int, int] = None, min_dimensions: Tuple[int, int] = None) -> Image:
    if max_dimensions is not None:
        ratios = [a/b for a, b in zip(img.size, max_dimensions)]
        if any([r > 1 for r in ratios]):
            # downscale so the largest ratio becomes 1
            size = np.array(img.size)//max(ratios)
            img = img.resize(size.astype(int), Image.BILINEAR)
    if min_dimensions is not None:
        # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
        padded_size = [max(img_dim, min_dim) for img_dim, min_dim in zip(img.size, min_dimensions)]
        if padded_size != list(img.size):  # assert hypothesis
            padded_im = Image.new('L', padded_size, 255)
            padded_im.paste(img, img.getbbox())
            img = padded_im
    return img
class LatexOCR:
    image_resizer = None
    last_pic = None

    @in_model_path()
    def __init__(self, arguments=None):
        if arguments is None:
            arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
        logging.getLogger().setLevel(logging.FATAL)  # reduce log output
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow logging
        with open(arguments.config, 'r') as f:
            params = yaml.load(f, Loader=yaml.FullLoader)  # load model parameters
        # parse and update the model arguments
        self.args = parse_args(Munch(params))
        self.args.update(**vars(arguments))
        self.args.wandb = False  # disable wandb
        # use cuda if available, unless no_cuda is set
        self.args.device = 'cuda' if torch.cuda.is_available() and not self.args.no_cuda else 'cpu'
        if not os.path.exists(self.args.checkpoint):
            download_checkpoints()
        # build the model and load the trained weights
        self.model = get_model(self.args)
        self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))
        self.model.eval()
        # if image_resizer.pth sits next to the checkpoint and resizing is not disabled, load it
        if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
            self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
                                          preact=True, stem_type='same', conv_layer=StdConv2dSame).to(self.args.device)
            self.image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(self.args.checkpoint), 'image_resizer.pth'), map_location=self.args.device))
            self.image_resizer.eval()
        # tokenizer used to turn predicted token ids back into LaTeX
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.args.tokenizer)

    @in_model_path()
    def __call__(self, img=None, resize=True) -> str:
        if type(img) is bool:
            img = None
        if img is None:
            if self.last_pic is None:
                return ''
            else:
                print('\nLast image is: ', end='')
                img = self.last_pic.copy()
        else:
            self.last_pic = img.copy()
        img = minmax_size(pad(img), self.args.max_dimensions, self.args.min_dimensions)
        if (self.image_resizer is not None and not self.args.no_resize) and resize:
            with torch.no_grad():
                input_image = img.convert('RGB').copy()
                r, w, h = 1, input_image.size[0], input_image.size[1]
                # iteratively let the resizer model pick a better width until it converges
                for _ in range(10):
                    h = int(h * r)  # height to resize
                    img = pad(minmax_size(input_image.resize((w, h), Image.Resampling.BILINEAR if r > 1 else Image.Resampling.LANCZOS), self.args.max_dimensions, self.args.min_dimensions))
                    t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
                    w = (self.image_resizer(t.to(self.args.device)).argmax(-1).item()+1)*32
                    logging.info('%s, %s, %s', r, img.size, (w, int(input_image.size[1]*r)))
                    if w == img.size[0]:
                        break
                    r = w/img.size[0]
        else:
            img = np.array(pad(img).convert('RGB'))
            t = test_transform(image=img)['image'][:1].unsqueeze(0)
        im = t.to(self.args.device)
        dec = self.model.generate(im.to(self.args.device), temperature=self.args.get('temperature', .25))
        pred = post_process(token2str(dec, self.tokenizer)[0])
        try:
            clipboard.copy(pred)
        except Exception:
            pass
        return pred
def output_prediction(pred, args):
    TERM = os.getenv('TERM', 'xterm')
    if not sys.stdout.isatty():
        TERM = 'dumb'
    try:
        from pygments import highlight
        from pygments.lexers import get_lexer_by_name
        from pygments.formatters import get_formatter_by_name

        if TERM.split('-')[-1] == '256color':
            formatter_name = 'terminal256'
        elif TERM != 'dumb':
            formatter_name = 'terminal'
        else:
            formatter_name = None
        if formatter_name:
            formatter = get_formatter_by_name(formatter_name)
            lexer = get_lexer_by_name('tex')
            print(highlight(pred, lexer, formatter), end='')
    except ImportError:
        TERM = 'dumb'
    if TERM == 'dumb':
        print(pred)
    if args.show or args.katex:
        try:
            if args.katex:
                raise ValueError
            tex2pil([f'$${pred}$$'])[0].show()
        except Exception:
            # render using katex instead
            import webbrowser
            from urllib.parse import quote
            url = 'https://katex.org/?data=' + \
                quote('{"displayMode":true,"leqno":false,"fleqn":false,"throwOnError":true,"errorColor":"#cc0000",\
"strict":"warn","output":"htmlAndMathml","trust":false,"code":"%s"}' % pred.replace('\\', '\\\\'))
            webbrowser.open(url)
def predict(model, file, arguments):
    img = None
    if file:
        try:
            img = Image.open(os.path.expanduser(file))
        except Exception as e:
            print(e, end='')
    else:
        try:
            # no file given: try to grab an image from the clipboard
            img = ImageGrab.grabclipboard()
        except NotImplementedError as e:
            print(e, end='')
    pred = model(img)
    output_prediction(pred, arguments)


def check_file_path(paths: List[Path], wdir: Optional[Path] = None) -> List[str]:
    files = []
    for path in paths:
        if isinstance(path, str):
            if path == '':
                continue
            path = Path(path)
        pathsi = ([path] if wdir is None else [path, wdir/path])
        for p in pathsi:
            if p.exists():
                files.append(str(p.resolve()))
            elif '*' in path.name:
                files.extend([str(pi.resolve()) for pi in p.parent.glob(p.name)])
    return list(set(files))
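With these pieces in place, the model can be used programmatically. A minimal sketch, assuming this file lives at pix2tex/cli.py (the image path is illustrative; with no arguments, __init__ falls back to settings/config.yaml and checkpoints/weights.pth as shown above):

from PIL import Image
from pix2tex.cli import LatexOCR

model = LatexOCR()               # loads config, weights and (if present) the image resizer
img = Image.open('formula.png')  # illustrative input path
print(model(img))                # predicted LaTeX string; also copied to the clipboard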
2. Results
Example image:
Program output:
M={\frac{x_{1}+x_{2}+\cdot\cdot\cdot+x_{n}}{n}}
Pasting the generated LaTeX into an editor to check it:
As the rendered result shows, the generated LaTeX is correct: it typesets the arithmetic mean of x_1 through x_n.
3. Notes
3.1 How it works
This is a deep learning model with an encoder-decoder architecture: the encoder is a ViT (Vision Transformer) and the decoder is a Transformer.
The dataset was collected from the web, including Wikipedia, CSDN and various other blogs. Data acquisition and preprocessing were handled by my teammate Jin.
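Schematically, inference wires the two parts together as below. This is a simplified sketch, not the project's actual model.generate: the decoder call signature, the token ids and the greedy decoding loop are illustrative stand-ins for the sampled decoding used in the real code.

import torch

@torch.no_grad()
def image_to_latex(encoder, decoder, tokenizer, im, max_len=512, bos_id=1, eos_id=2):
    context = encoder(im)                    # ViT: image -> sequence of patch embeddings
    tokens = torch.tensor([[bos_id]])        # start from the [BOS] token
    for _ in range(max_len):
        logits = decoder(tokens, context=context)          # Transformer decoder with cross-attention
        next_tok = logits[:, -1].argmax(-1, keepdim=True)  # greedy pick (the real code samples)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos_id:        # stop at [EOS]
            break
    return tokenizer.decode(tokens[0].tolist())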
3.2 Current limitations
The network works best on relatively small, low-resolution formula images. To handle larger images we added a second, helper network: it predicts the most suitable resolution, and the input image is resized to that resolution as a preprocessing step.
Even with this preprocessing, recognition of formula images is not guaranteed to be 100% correct. We recommend not scaling formula images up too much; if the model fails to recognize an image, trying a different resolution can help.
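The helper network is the ResNetV2 classifier trained by train_resize above: prepare_data labels an image of original width w with class w//32 - 1, and the resize loop in cli inverts this with (class + 1)*32. A worked example of the mapping (the width value is illustrative):

width = 608                  # original render width in pixels
label = width // 32 - 1      # -> 18: the class prepare_data assigns
restored = (label + 1) * 32  # -> 608: the width the cli resize loop targets
assert restored == width     # exact here because 608 is a multiple of 32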
ResNetV2:
# Excerpted from timm.models.resnetv2 (the implementation behind the resizer model above).
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from timm.models.layers import StdConv2d, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d

# make_div, named_apply, _init_weights and _load_weights are further timm-internal
# helpers defined alongside this code; they are omitted from the excerpt.


class PreActBottleneck(nn.Module):
    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def zero_init_last(self):
        nn.init.zeros_(self.conv3.weight)

    def forward(self, x):
        x_preact = self.norm1(x)
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)
        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut


class Bottleneck(nn.Module):
    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def zero_init_last(self):
        nn.init.zeros_(self.norm3.weight)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)
        # residual branch
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.drop_path(x)
        x = self.act3(x + shortcut)
        return x


class DownsampleConv(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
            conv_layer=None, norm_layer=None):
        super(DownsampleConv, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(x))


class DownsampleAvg(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
            preact=True, conv_layer=None, norm_layer=None):
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))


class ResNetStage(nn.Module):
    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
                 avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
                 act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
        super(ResNetStage, self).__init__()
        first_dilation = 1 if dilation in (1, 2) else 2
        layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
        proj_layer = DownsampleAvg if avg_down else DownsampleConv
        prev_chs = in_chs
        self.blocks = nn.Sequential()
        for block_idx in range(depth):
            drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
            stride = stride if block_idx == 0 else 1
            self.blocks.add_module(str(block_idx), block_fn(
                prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
                first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
                **layer_kwargs, **block_kwargs))
            prev_chs = out_chs
            first_dilation = dilation
            proj_layer = None

    def forward(self, x):
        x = self.blocks(x)
        return x


def is_stem_deep(stem_type):
    return any([s in stem_type for s in ('deep', 'tiered')])


def create_resnetv2_stem(
        in_chs, out_chs=64, stem_type='', preact=True,
        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
    stem = OrderedDict()
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if is_stem_deep(stem_type):
        # A 3 deep 3x3 conv stack as in ResNet V1D models
        if 'tiered' in stem_type:
            stem_chs = (3 * out_chs // 8, out_chs // 2)  # 'T' resnets in resnet.py
        else:
            stem_chs = (out_chs // 2, out_chs // 2)  # 'D' ResNets
        stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
        stem['norm1'] = norm_layer(stem_chs[0])
        stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
        stem['norm2'] = norm_layer(stem_chs[1])
        stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
        if not preact:
            stem['norm3'] = norm_layer(out_chs)
    else:
        # The usual 7x7 stem conv
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
        if not preact:
            stem['norm'] = norm_layer(out_chs)

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        stem['pad'] = nn.ConstantPad2d(1, 0.)
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    elif 'same' in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    else:
        # the usual PyTorch symmetric padding
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    return nn.Sequential(stem)


class ResNetV2(nn.Module):
    def __init__(
            self, layers, channels=(256, 512, 1024, 2048),
            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
            drop_rate=0., drop_path_rate=0., zero_init_last=False):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_resnetv2_stem(
            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
        stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
            prev_chs = out_chs
            curr_stride *= stride
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last=True):
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
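As a quick sanity check, the resizer variant used in this project can be instantiated straight from timm and probed with a dummy grayscale image; the head emits one logit per 32 px width class. The max_dimensions value below is illustrative (the real one comes from the config):

import torch
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame

max_dimensions = (1024, 512)  # illustrative; taken from the YAML config in practice
resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(max_dimensions)//32, global_pool='avg',
                   in_chans=1, drop_rate=.05, preact=True, stem_type='same', conv_layer=StdConv2dSame)
resizer.eval()

with torch.no_grad():
    dummy = torch.zeros(1, 1, 64, 256)           # one grayscale formula image
    logits = resizer(dummy)                      # shape (1, 32): one class per 32 px of width
    print((logits.argmax(-1).item() + 1) * 32)   # predicted target width in pixels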
3.3 Next steps
Our next task is to build a front end that accepts a formula image and returns the corresponding LaTeX, and to connect it to the model running on the back end.