(1) Latex_OCR project: Encoder, Decoder, get_encoder, get_decoder.
Last time we downloaded the fine-tuned model from the server and ran inference locally to test its performance. The results confirmed that fine-tuning improved the model's ability to solve math problems.
This time I want to build a new project that converts formulas in images into LaTeX, so that users can extract formulas from images and pass them on to our large model.
This post covers building the model architecture: the encoder is a ViT and the decoder is a Transformer.
1. Code
Encoder: ViT.
import torch
import torch.nn as nn
from x_transformers import Encoder
from einops import rearrange, repeat

class ViTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        max_width,
        max_height,
        patch_size,
        attn_layers,
        channels=1,
        num_classes=None,
        dropout=0.,
        emb_dropout=0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (max_width // patch_size) * (max_height // patch_size)
        patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        self.max_width = max_width
        self.max_height = max_height
        # positional embeddings cover the largest possible patch grid, +1 for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

    def forward(self, img, **kwargs):
        p = self.patch_size
        # cut the image into p x p patches and flatten each patch into a vector
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape
        # prepend a learnable CLS token to every sequence in the batch
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        # pick positional embeddings as if the image sat in the top-left corner
        # of the full max_height x max_width grid, so images of different
        # sizes share consistent positions
        h, w = torch.tensor(img.shape[2:]) // p
        pos_emb_ind = repeat(torch.arange(h) * (self.max_width // p - w), 'h -> (h w)', w=w) + torch.arange(h * w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
        x += self.pos_embedding[:, pos_emb_ind]
        x = self.dropout(x)
        x = self.attn_layers(x, **kwargs)
        x = self.norm(x)
        return x
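The positional-embedding indexing in forward is what lets one model handle images of varying size: each patch gets the embedding of the slot it would occupy in the full max_height x max_width grid. A small worked example, with illustrative numbers rather than the project's actual config:

import torch
from einops import repeat

# Suppose max_width = 256 and patch_size = 16, so the full grid is 16 patches wide.
# A 32 x 48 image then produces an h = 2 by w = 3 patch grid.
h, w, grid_w = 2, 3, 16
pos_emb_ind = repeat(torch.arange(h) * (grid_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
print(pos_emb_ind)  # tensor([ 0,  1,  2, 16, 17, 18]) -- row-major slots in the 16-wide grid
# Shifting by +1 and prepending a 0 then reserves index 0 for the CLS token.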
The corresponding get_encoder() method:
# args is expected to be a dict-like config object (it must support .get)
def get_encoder(args):
    return ViTransformerWrapper(
        max_width=args.max_width,
        max_height=args.max_height,
        channels=args.channels,
        patch_size=args.patch_size,
        emb_dropout=args.get('emb_dropout', 0),
        attn_layers=Encoder(
            dim=args.dim,
            depth=args.encoder_depth,
            heads=args.heads,
        )
    )
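A minimal usage sketch. Since the function calls args.get, I assume args is a dict-like object such as a Munch (in practice these values would come from a config file); the numbers below are only for illustration:

import torch
from munch import Munch

args = Munch(max_width=128, max_height=64, channels=1,
             patch_size=16, dim=256, encoder_depth=4, heads=8)
encoder = get_encoder(args)

img = torch.randn(1, 1, 64, 128)  # (batch, channels, height, width), grayscale
out = encoder(img)
print(out.shape)  # torch.Size([1, 33, 256]): 8*4 = 32 patches + 1 CLS token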
We also designed an improved variant of the ViT that swaps the linear patch embedding for a ResNetV2 convolutional backbone (a hybrid ViT built from timm components), putting more feature-extraction layers in front of the attention blocks:
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from einops import repeat

class CustomVisionTransformer(VisionTransformer):
    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
        self.height, self.width = img_size
        self.patch_size = patch_size

    def forward_features(self, x):
        B, c, h, w = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        # same resolution-aware positional indexing as in ViTransformerWrapper above
        h, w = h // self.patch_size, w // self.patch_size
        pos_emb_ind = repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
The corresponding get_encoder() method:
def get_encoder(args):
    backbone = ResNetV2(
        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    # the backbone downsamples by 2**(len(layers) + 1) in total (stem x4,
    # then x2 per extra stage), so the effective patch size must be a
    # multiple of that factor
    min_patch_size = 2 ** (len(args.backbone_layers) + 1)

    def embed_layer(**x):
        ps = x.pop('patch_size', min_patch_size)
        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
        # HybridEmbed patchifies the backbone's feature map, so the patch
        # size is given in feature-map units
        return HybridEmbed(**x, patch_size=ps // min_patch_size, backbone=backbone)

    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      embed_layer=embed_layer)
    return encoder
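To make the min_patch_size arithmetic concrete (the layer counts here are hypothetical, not taken from the project's config):

backbone_layers = [2, 3, 7]                       # hypothetical 3-stage ResNetV2
min_patch_size = 2 ** (len(backbone_layers) + 1)  # = 16
# patch_size must then be 16, 32, 48, ..., and HybridEmbed slices the
# backbone's feature map with patch_size // 16.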
Decoder: Transformer.
import torch
import torch.nn.functional as F
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p
from x_transformers import TransformerWrapper, Decoder

class CustomARWrapper(AutoregressiveWrapper):
    def __init__(self, *args, **kwargs):
        super(CustomARWrapper, self).__init__(*args, **kwargs)

    @torch.no_grad()
    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)
        if num_dims == 1:
            start_tokens = start_tokens[None, :]
        b, t = start_tokens.shape
        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
        for _ in range(seq_len):
            # only feed the last max_seq_len tokens into the network
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]
            # take the logits at the final position
            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
            if filter_logits_fn in {top_k, top_p}:
                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)
            # sample the next token and append it to the sequence
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)
            # stop once every sequence in the batch has produced an EOS token
            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
                break
        # strip the prompt; return only the newly generated tokens
        out = out[:, t:]
        if num_dims == 1:
            out = out.squeeze(0)
        self.net.train(was_training)
        return out
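With the default filter_thres=0.9, only roughly the top 10% of vocabulary logits survive filtering before sampling. As far as I can tell, x_transformers' top_k behaves like the sketch below (treat it as my reading of the library, not its exact source):

import math
import torch

def top_k_sketch(logits, thres=0.9):
    # keep the ceil((1 - thres) * vocab_size) largest logits and mask the
    # rest to -inf, so multinomial sampling only sees likely tokens
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(1, ind, val)
    return filtered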
The get_decoder() method:
def get_decoder(args):
    return CustomARWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                **args.decoder_args
            )),
        pad_value=args.pad_token)
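How the two halves meet at inference time: the encoder output is passed to generate through **kwargs as context, which the Decoder can cross-attend to (this presumes cross_attend=True is set in args.decoder_args). A rough sketch; bos_token and eos_token are assumed config fields, not shown above:

encoder = get_encoder(args)
decoder = get_decoder(args)

encoded = encoder(img)                        # (1, n+1, dim) image features
start = torch.LongTensor([[args.bos_token]])  # assumed begin-of-sequence id
tokens = decoder.generate(start,
                          seq_len=args.max_seq_len,
                          eos_token=args.eos_token,  # assumed end-of-sequence id
                          context=encoded,           # cross-attended by the Decoder
                          temperature=0.2)
# `tokens` would then be detokenized back into a LaTeX string.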
2. Notes
2.1 How it works
This is a deep learning model with an encoder-decoder architecture: the encoder is a ViT and the decoder is a Transformer. The encoder maps the formula image to a sequence of patch features, and the decoder generates the LaTeX token sequence conditioned on those features.
The dataset was collected from the web, including Wikipedia, CSDN, and various other blogs. Data collection and preprocessing were done by my classmate Jin.
2.2 Next steps
Implement the model's training (train) and evaluation (eval) loops, and package the code for loading and running the trained model.
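As a preview of the training step: AutoregressiveWrapper already implements teacher forcing, and in the x_transformers version this project builds on, calling the wrapper on a token sequence returns the next-token cross-entropy loss directly. A rough sketch under the same cross-attention assumption as above (im, tok, and optimizer are hypothetical names):

encoded = encoder(im)                 # im: a batch of formula images
loss = decoder(tok, context=encoded)  # tok: the matching LaTeX token ids
loss.backward()
optimizer.step()
optimizer.zero_grad()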