Datawhale干货
作者:安晟、袁明坤,Datawhale成员
在CV领域中,transformer除了分类还能做什么?本文将采用一个单词识别任务数据集,讲解如何使用transformer实现一个简单的OCR文字识别任务,并从中体会transformer是如何应用到除分类以外更复杂的CV任务中的。全文分为四部分:
一、数据集简介与获取
二、数据分析与关系构建
三、如何将transformer引入OCR
四、训练框架代码讲解
注:本文围绕如何设计模型和训练架构来解决OCR任务,文章含完整实践,代码很长建议收藏。不熟悉transformer的小伙伴可以点击这里回顾。
整个文字识别任务中,主要包括以下几个文件:
- analysis_recognition_dataset.py (数据集分析脚本)
- ocr_by_transformer.py (OCR任务训练脚本)
- transformer.py (transformer模型文件)
- train_utils.py (训练相关辅助函数,loss、optimizer等)
其中 ocr_by_transformer.py 为主要的训练脚本,其依托 train_utils.py 和 transformer.py 两个文件构建 transformer 来完成字符识别模型的训练
一、数据集简介与获取
本文使用的数据集基于ICDAR2015 Incidental Scene Text
中的 Task 4.3: Word Recognition
,这是一个著名的自然场景下文本识别数据集,本次用来进行单词识别任务,我们去掉了其中一些图片,来简化这个实验的难度,因此本文的数据集与原始数据集略有差别。
为了能够更好的进行数据共享和版本管控,我们选择在线调用数据集,将简化后的数据集存放在专门的数据共享平台,数据开源地址: https://gas.graviti.cn/dataset/datawhale/ICDAR2015 ,有相关问题可以直接在数据集讨论区交流。
该数据集包含了众多自然场景图像中出现的文字区域,数据中训练集含有4326张图像,测试集含有1992张图像,他们都是从原始大图中依据文字区域的bounding box裁剪出来的,图像中的文字基本处于图片中心位置。
数据集中图像类似如下样式:
word_104.png, "Optical" |
---|
数据本身以图像展示,对应的标签存储在CLASSIFICATION
中,后文代码中标签获取,将直接得到一个包含所有字符的列表,这也是为了方便标签易用性进行的存储选择。
下面简单介绍数据集的快速使用:
本地下载安装 tensorbay
pip3 install tensorbay
打开本文数据集链接:https://gas.graviti.cn/dataset/datawhale/ICDAR2015
将数据集fork到自己账户下
点击网页上方开发者工具 --> AccessKey --> 新建一个AccessKey --> 复制这个Key
from tensorbay import GAS
from tensorbay.dataset import Dataset
# GAS凭证
KEY = 'Accesskey-***************80a' # 添加自己的AccessKey
gas = GAS(KEY)
# 获取数据集
dataset = Dataset("ICDAR2015", gas)
# dataset.enable_cache('./data') # 开启本语句,选择将数据建立本地缓存
# 训练集和验证集
train_segment = dataset["train"]valid_segment = dataset['valid']
# 数据及标签
for data in train_segment:
# 图像数据
img = data.open()
# 图像标签
label = data.label.classification.category
break
通过以上代码获取的图像及标签形式如下:
img | label |
---|---|
['C', 'A', 'U', 'T', 'I', 'O', 'N'] |
通过以上简单代码便可以快速获取图像数据及标签,但程序每次运行都会自动去平台下载数据,因此耗时较长,建议开启本地缓存,一次下载多次使用,当不再使用数据时便可以将数据删除。
二、数据分析与关系构建
开始实验前,我们先对数据进行简单分析,只有对数据的特性足够了解,才能够更好的搭建出baseline,在训练中少走弯路。
运行下面代码,即可一键完成对于数据集的简单分析:
python analysis_recognition_dataset.py
具体地,这个脚本所做的工作包括:对数据进行标签字符统计(有哪些字符、每个字符出现次数多少)、最长标签长度统计,图像尺寸分析等,并且构建字符标签的映射关系文件 lbl2id_map.txt
。
下面我们来一点点看代码:
注:本文代码开源地址:
https://github.com/datawhalechina/dive-into-cv-pytorch/tree/master/code/chapter06_transformer/6.2_recognition_by_transformer(online_dataset)
首先完成准备工作,导入需要的库,并设置好相关目录或文件的路径
import os
from PIL import Image
import tqdm
from tensorbay import GAS
from tensorbay.dataset import Dataset
# GAS凭证
KEY = 'Accesskey-************************480a' # 添加自己的AccessKey
gas = GAS(KEY)
# 数据集获取并本地缓存
dataset = Dataset("ICDAR2015", gas)
dataset.enable_cache('./data') # 数据缓存地址
# 获取训练集和验证集
train_segment = dataset["train"]
valid_segment = dataset['valid']
# 中间文件存储路径,存储标签字符与其id的映射关系
base_data_dir = './'
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')
2.1 标签最长字符个数统计
首先统计数据集最长label中包含的字符数量,此处要将训练集和验证集中的最长标签都进行统计,进而得到最长标签所含字符。
def statistics_max_len_label(segment):
"""
统计标签中最长的label所包含的字符数
"""
max_len = -1
for data in segment:
lbl_str = data.label.classification.category # 获取标签
lbl_len = len(lbl_str)
max_len = max_len if max_len > lbl_len else lbl_len
return max_len
train_max_label_len = statistics_max_len_label(train_segment) # 训练集最长label
valid_max_label_len = statistics_max_len_label(valid_segment) # 验证集最长label
max_label_len = max(train_max_label_len, valid_max_label_len) # 全数据集最长label
print(f"数据集中包含字符最多的label长度为{max_label_len}")
数据集中最长label含有21个字符,这将为后面transformer模型搭建时的时间步长度的设置提供参考。
2.2 标签所含字符统计
下面代码查看数据集中出现过的所有字符:
def statistics_label_cnt(segment, lbl_cnt_map):
"""
统计标签文件中label都包含哪些字符以及各自出现的次数
lbl_cnt_map : 记录标签中字符出现次数的字典
"""
for data in segment:
lbl_str = data.label.classification.category # 获取标签
for lbl in lbl_str:
if lbl not in lbl_cnt_map.keys():
lbl_cnt_map[lbl] = 1
else:
lbl_cnt_map[lbl] += 1
lbl_cnt_map = dict() # 用于存储字符出现次数的字典
statistics_label_cnt(train_segment, lbl_cnt_map) # 训练集中字符出现次数统计
print("训练集中label中出现的字符:")
print(lbl_cnt_map)
statistics_label_cnt(valid_segment, lbl_cnt_map) # 训练集和验证集中字符出现次数统计
print("训练集+验证集label中出现的字符:")
print(lbl_cnt_map)
输出结果为:
训练集中label中出现的字符:
{'C': 593, 'A': 1189, 'U': 319, 'T': 896, 'I': 861, 'O': 965, 'N': 785, 'D': 383, 'W': 179, 'M': 367, 'E': 1423, 'X': 110, '$': 46, '2': 121, '4': 44, 'L': 745, 'F': 259, 'P': 389, 'R': 836, 'S': 1164, 'a': 843, 'v': 123, 'e': 1057, 'G': 345, "'": 51, 'r': 655, 'k': 96, 's': 557, 'i': 651, 'c': 318, 'V': 158, 'H': 391, '3': 50, '.': 95, '"': 8, '-': 68, ',': 19, 'Y': 229, 't': 563, 'y': 161, 'B': 332, 'u': 293, 'x': 27, 'n': 605, 'g': 171, 'o': 659, 'l': 408, 'd': 258, 'b': 88, 'p': 197, 'K': 163, 'J': 72, '5': 80, '0': 203, '1': 186, 'h': 299, '!': 51, ':': 19, 'f': 133, 'm': 202, '9': 66, '7': 45, 'j': 15, 'z': 12, '´': 3, 'Q': 19, 'Z': 29, '&': 9, ' ': 50, '8': 47, '/': 24, '#': 16, 'w': 97, '?': 5, '6': 40, '[': 2, ']': 2, 'É': 1, 'q': 3, ';': 3, '@': 4, '%': 28, '=': 1, '(': 6, ')': 5, '+': 1}
训练集+验证集label中出现的字符:
{'C': 893, 'A': 1827, 'U': 467, 'T': 1315, 'I': 1241, 'O': 1440, 'N': 1158, 'D': 548, 'W': 288, 'M': 536, 'E': 2215, 'X': 181, '$': 57, '2': 141, '4': 53, 'L': 1120, 'F': 402, 'P': 582, 'R': 1262, 'S': 1752,