本范例使用经典的 CRNN + CTC Loss 的 OCR 模型来识别验证码。
我们通过导入一个叫 captcha 的库来生成验证码。
生成的验证码字符由数字和大写字母组成。
项目参考:https://github.com/ypwhs/captcha_break
#!pip install captcha torchkeras
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor, to_pil_image
from tqdm import tqdm
import random
import numpy as np
import torchkeras
from pathlib import Path
from collections import OrderedDict
# Output alphabet of the recogniser. Index 0 ('-') is the special CTC [blank]
# symbol; it is followed by the ten digits and the upper-case letters.
characters = ''.join(('-', '0123456789', 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
n_classes = len(characters)

# Captcha image size in pixels.
height = 64
width = 192

txt_length = 4  # number of characters in each captcha to recognise
seq_length = 12  # CRNN output sequence length; generally needs seq_length >= 2*txt_length + 1
一,准备数据
# Demo cell: build a wide (320x64) captcha generator that uses the SimHei.ttf
# font file located in the `assets` folder next to the installed torchkeras
# package, then render a sample containing Chinese text.
# NOTE(review): presumably SimHei is chosen because it contains CJK glyphs — confirm.
# `generator` is rebound later in the file with the default fonts.
from captcha.image import ImageCaptcha
generator = ImageCaptcha(width=320, height=64,
fonts=[str(Path(torchkeras.__file__).parent/'assets'/'SimHei.ttf') ],
font_sizes=[40,45])
generator.generate_image('中国人民很行')
class CaptchaDataset(Dataset):
    """Endless-style synthetic captcha dataset.

    Every access renders a brand-new random captcha, so the dataset has no
    stored samples; ``length`` only controls how many items one pass yields.

    Args:
        characters: Alphabet string. The first character is skipped when
            sampling text (it is the CTC blank); label indices refer to
            positions in this full string.
        length: Value reported by ``len(dataset)``.
        width: Image width passed to ``ImageCaptcha``.
        height: Image height passed to ``ImageCaptcha``.
        input_length: Value emitted as the per-sample input (sequence) length.
        label_length: Number of characters in each generated captcha.
    """

    def __init__(self, characters, length,
                 width, height, input_length, label_length):
        super().__init__()
        self.characters = characters
        self.n_class = len(characters)
        self.length = length
        self.width = width
        self.height = height
        self.input_length = input_length
        self.label_length = label_length
        self.generator = ImageCaptcha(width=width, height=height)

    def __len__(self):
        """Return the nominal number of samples per pass."""
        return self.length

    def __getitem__(self, index):
        """Generate one random sample; ``index`` is not used.

        Returns:
            A tuple ``(image, target, input_length, target_length)`` where
            ``image`` is the rendered captcha converted with ``to_tensor``,
            ``target`` is a long tensor of character indices, and the two
            lengths are long tensors of shape ``(1,)``.
        """
        # Sample only from the real characters, never the leading blank.
        alphabet = self.characters[1:]
        random_str = ''.join(random.choice(alphabet) for _ in range(self.label_length))

        rendered = self.generator.generate_image(random_str)
        image = to_tensor(rendered)

        label_ids = [self.characters.find(ch) for ch in random_str]
        target = torch.tensor(label_ids, dtype=torch.long)

        input_length = torch.full((1,), self.input_length, dtype=torch.long)
        target_length = torch.full((1,), self.label_length, dtype=torch.long)
        return image, target, input_length, target_length
# Build the training / validation datasets and loaders.
# The dataset sizes are multiples of batch_size: 100 batches for training and
# 20 batches for validation per pass.
batch_size = 128
ds_train= CaptchaDataset(characters, 100 * batch_size,
width, height, seq_length, txt_length)
ds_val = CaptchaDataset(characters, 20 * batch_size,
width, height, seq_length, txt_length)
# Samples are generated randomly on access (the index is ignored by the
# dataset), so no shuffling is requested here.
dl_train = DataLoader(ds_train, batch_size=batch_size, num_workers=4)
dl_val = DataLoader(ds_val, batch_size=batch_size, num_workers=4)
# One-item dataset used for quick visual checks; each ds_test[0] access
# produces a fresh random captcha.
ds_test = CaptchaDataset(characters, 1, width, height, seq_length, txt_length)
image, target, input_length, label_length = ds_test[0]
# Show the ground-truth text and the two length tensors, then the image itself.
print(''.join([characters[x] for x in target]), input_length, label_length)
to_pil_image(image)
二,定义模型
class CRNN(nn.Module):
    """Convolutional-recurrent network for sequence recognition on images.

    The model is a stack of five conv blocks (each: two conv+BN+ReLU layers
    followed by max-pooling), a dropout layer, a 2-layer bidirectional LSTM
    that reads the feature map column by column, and a linear classifier.

    Args:
        n_classes: Number of output classes per time step.
        input_shape: ``(channels, height, width)`` of the input images. It is
            used to size the first convolution and to infer the LSTM input
            feature size.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128)):
        super().__init__()
        self.input_shape = input_shape

        channels = [32, 64, 128, 256, 256]
        layers = [2, 2, 2, 2, 2]
        kernels = [3, 3, 3, 3, 3]
        # The last pool halves only the height, keeping the width (time axis).
        pools = [2, 2, 2, 2, (2, 1)]
        hidden_size = 128

        modules = OrderedDict()

        def cba(name, in_channels, out_channels, kernel_size):
            """Register a conv -> batch-norm -> ReLU triple under ``name``."""
            modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
                                               padding=(1, 1) if kernel_size == 3 else 0)
            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
            modules[f'relu{name}'] = nn.ReLU(inplace=True)

        # Take the input channel count from input_shape instead of assuming RGB.
        last_channel = input_shape[0]
        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
            for layer in range(1, n_layer + 1):
                cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
                last_channel = n_channel
            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
        modules['dropout'] = nn.Dropout(0.25, inplace=True)

        self.cnn = nn.Sequential(modules)
        self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=hidden_size,
                            num_layers=2, bidirectional=True)
        # A bidirectional LSTM concatenates both directions: 2 * hidden_size features.
        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=n_classes)

    def infer_features(self):
        """Return the per-time-step feature size (channels * height) after the CNN.

        The probe pass runs in eval mode and without autograd so that it does
        not update BatchNorm running statistics, draw dropout random numbers,
        or build a graph. The previous training/eval mode is restored afterwards.
        """
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with torch.no_grad():
                x = torch.zeros((1,) + tuple(self.input_shape))
                x = self.cnn(x)
        finally:
            self.cnn.train(was_training)
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return x.shape[1]

    def forward(self, x):
        """Map a batch of images to per-time-step class scores.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)``.

        Returns:
            Unnormalised scores of shape ``(seq_len, batch, n_classes)``, where
            ``seq_len`` is the width of the CNN feature map.
        """
        x = self.cnn(x)
        # Merge channels and height into one feature axis: (batch, C*H, W).
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        # The LSTM expects (seq_len, batch, features); width is the time axis.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x
# Instantiate the network for the captcha image size and sanity-check the
# output shape on a dummy batch of 32 all-zero images.
net = CRNN(n_classes, input_shape=(3, height, width))
inputs = torch.zeros((32, 3, height, width))
outputs = net(inputs)
print(outputs.shape) # the LSTM output puts the sequence-length dimension first by default
# Move the model to the GPU (requires CUDA); the trailing semicolon suppresses
# the notebook's echo of the returned module.
net.cuda();
torch.Size([12, 32, 37])
三,训练模型
# 解码函数和计算准确率函数
def decode_target(sequence):
    """Turn a sequence of ground-truth label indices into its text.

    Each index is looked up in the module-level ``characters`` string; any
    space characters in the result are removed.
    """
    symbols = (characters[idx] for idx in sequence)
    text = ''.join(symbols)
    return text.replace(' ', '')
def decode(sequence, charset=None):
    """Greedy CTC decoding of a sequence of per-step class indices.

    Consecutive repeats are collapsed into one symbol and blank symbols are
    then removed, so a blank between two equal characters keeps both
    (``A-A`` -> ``AA``) while an uninterrupted run yields one (``AAA`` -> ``A``).

    Args:
        sequence: Iterable of integer-like indices into the alphabet.
        charset: Alphabet string whose first character is the blank symbol.
            Defaults to the module-level ``characters``.

    Returns:
        The decoded text; an empty string for an empty or all-blank sequence.
    """
    if charset is None:
        charset = characters
    blank = charset[0]

    decoded = []
    previous = None
    for idx in sequence:
        symbol = charset[idx]
        # Emit a symbol only where a new run starts and it is not the blank.
        if symbol != previous and symbol != blank:
            decoded.append(symbol)
        previous = symbol
    return ''.join(decoded)
def eval_acc(targets, preds):
    """Compute the fraction of samples whose decoded prediction equals the target text.

    Args:
        targets: Tensor of ground-truth label indices, one row per sample.
        preds: 3-D tensor of class scores with the batch on axis 1 (it is
            swapped to batch-first before decoding).

    Returns:
        The mean of the per-sample exact-match flags.
    """
    # Best class per step, rearranged so that each row is one sample.
    best_per_step = preds.detach().argmax(dim=-1).transpose(0, 1)

    target_rows = targets.cpu().numpy()
    pred_rows = best_per_step.cpu().numpy()

    matches = [decode_target(truth) == decode(guess)
               for truth, guess in zip(target_rows, pred_rows)]
    return np.array(matches).mean()
import torch.nn.functional as F
from torchkeras import KerasModel
#我们覆盖KerasModel的StepRunner以实现自定义训练逻辑。
#注意这里把acc指标的结果写在了step_loss中以便和loss一样在Epoch上求平均,这是一个非常灵活而且有用的写法。
class StepRunner:
    """Runs a single training or evaluation step with CTC loss.

    The accuracy is reported inside the step-loss dict (next to the loss) so
    that it is handled the same way as the loss by whatever consumes that dict.

    Args:
        net: The network to run; it is put into train or eval mode on construction.
        loss_fn: Stored but not used by this runner (the CTC loss is computed directly).
        accelerator: Object providing ``backward`` and ``gather``.
        stage: ``"train"`` enables optimisation; any other value means evaluation.
        metrics_dict: Stored but not used by this runner.
        optimizer: Optional optimizer stepped during training.
        lr_scheduler: Optional scheduler stepped after each optimizer step.
    """

    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None,
                 optimizer = None, lr_scheduler = None
                 ):
        self.net = net
        self.loss_fn = loss_fn
        self.metrics_dict = metrics_dict
        self.stage = stage
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator

        if self.stage != 'train':
            self.net.eval()
        else:
            self.net.train()

    def __call__(self, batch):
        """Process one batch.

        Args:
            batch: Tuple ``(images, targets, input_lengths, target_lengths)``.

        Returns:
            ``(step_losses, step_metrics)``: the first dict holds
            ``<stage>_loss`` and ``<stage>_acc``; the second holds ``lr``
            during training and is empty otherwise.
        """
        images, targets, input_lengths, target_lengths = batch

        # Forward pass and CTC loss on the log-probabilities.
        scores = self.net(images)
        log_probs = F.log_softmax(scores, dim=-1)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
        acc = eval_acc(targets, scores)

        is_training = self.stage == "train"

        # Backward pass and parameter update (training with an optimizer only).
        if is_training and self.optimizer is not None:
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        all_loss = self.accelerator.gather(loss).sum()

        step_losses = {
            self.stage + "_loss": all_loss.item(),
            self.stage + '_acc': acc,
        }

        step_metrics = {}
        if is_training:
            if self.optimizer is None:
                current_lr = 0.0
            else:
                current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
            step_metrics['lr'] = current_lr
        return step_losses, step_metrics
# Install the custom step logic on KerasModel (class-level assignment, so it
# applies to every KerasModel created afterwards).
KerasModel.StepRunner = StepRunner
# loss_fn is None because the custom StepRunner computes the CTC loss itself.
model = KerasModel(net,
loss_fn=None,
optimizer = torch.optim.AdamW(net.parameters(),lr = 2e-6)
)
# Load weights from an existing checkpoint before training.
# NOTE(review): this presumably fails on a first run when 'ctc_crnn.pt' does
# not exist yet — confirm against torchkeras' load_ckpt.
model.load_ckpt('ctc_crnn.pt')
# Train, monitoring validation accuracy (higher is better).
# NOTE(review): `visdis` is only defined further down in this file (in the
# evaluation section); that definition must be executed before this call.
# NOTE(review): `quiet` presumably reduces per-epoch output once epoch > 5 —
# confirm against the torchkeras fit() documentation.
model.fit(
train_data = dl_train,
val_data= dl_val,
ckpt_path='ctc_crnn.pt',
epochs=30,
patience=10,
monitor="val_acc",
mode="max",
plot = True,
wandb = False,
callbacks=[visdis],
quiet = lambda epoch: epoch>5
)
四,评估模型
def display_fn(model, max_tries=100):
    """Show one freshly generated captcha together with the model's prediction.

    New random samples are drawn from ``ds_test`` until the model mispredicts
    one, so that a failure case is displayed when one can be found. The search
    is bounded: if no misprediction occurs within ``max_tries`` samples, the
    last (correctly predicted) sample is shown instead of looping forever.

    Args:
        model: Callable network mapping a ``(1, C, H, W)`` CUDA tensor to
            sequence-first class scores; it is switched to eval mode.
        max_tries: Maximum number of samples to draw; at least one is always drawn.
    """
    model.eval()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max(1, max_tries)):
            image, target, input_length, label_length = ds_test[0]
            output = model(image.unsqueeze(0).cuda())
            output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
            truth = decode_target(target)
            guess = decode(output_argmax[0])
            if truth != guess:
                break
    print('gt:', truth, ' ', 'pred:', guess)
    display(to_pil_image(image))
# Wrap display_fn in torchkeras' VisDisplay callback, bound to `model`.
# NOTE(review): the earlier model.fit(...) call passes `visdis` in `callbacks`,
# so these lines have to be executed before that call.
from torchkeras.kerascallbacks import VisDisplay
visdis = VisDisplay(display_fn,model)
五,使用模型
def predict(model,image):
    """Recognise the text in a single captcha image.

    The model is moved to the GPU and switched to eval mode so that BatchNorm
    uses its running statistics and Dropout is disabled for this single-image
    batch; the forward pass runs without autograd.

    Args:
        model: Network returning sequence-first class scores; requires CUDA.
        image: Image accepted by ``to_tensor`` (e.g. the PIL image produced by
            ``ImageCaptcha.generate_image``).

    Returns:
        The greedily CTC-decoded text.
    """
    model.cuda()
    model.eval()
    tensor = to_tensor(image)
    with torch.no_grad():
        output = model(tensor.unsqueeze(0).cuda())
    # Sequence-first scores -> batch-first best class per step.
    output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
    preds = decode(output_argmax[0])
    return preds
# Reload the checkpoint written during training, then try the model on a
# hand-picked captcha text that contains a repeated character.
model.load_ckpt('ctc_crnn.pt')
# Rebinds `generator` to a captcha generator matching the model's input size.
generator = ImageCaptcha(width=width, height=height)
image = generator.generate_image('ABBD')
# Bare expressions: in a notebook these display the image and the prediction.
image
predict(model,image)
六,保存模型
# Save only the wrapped network's weights (its state_dict) to 'best.pt'.
torch.save(model.net.state_dict(),'best.pt')
公众号算法美食屋后台回复关键词:CRNN,获取本文notebook源代码和B站视频讲解。