背景:
面临一个银行票据识别任务。目标是:将一行金额或日期描述文本作OCR识别。
由于数据包含手写体和多版本多字体机打文字,采用单文字拼接方式做数据增强。
数据增强思路:
1、将单行文字图像切成单字分别存储,得到某个字的多个表达形式集合。
2、对于每一条模拟label,在某字的表达形式集合中采样。
3、拼接字生成增强数据。
涉及到一些问题的解决:
1、cv2 读写中文目录报错的问题。
2、切分成字的过程要考虑原始图像的明暗变化,实现不同明暗程度下的单字样本采集。
3、单字被切破的情况可通过低宽度碎片合并实现一定程度的修补。
基于垂直投影的单行文字图像_字分割代码与效果:
代码:
def gather(img_thre, stride = 4):
h, w = img_thre.shape
mat = np.zeros_like(img_thre)
mat = np.hstack((mat, np.zeros((h, stride))))
mat = np.vstack((mat, np.zeros((stride, w + stride))))
for x in range(0, h, stride):
for y in range(0, w, stride):
mat[x, y] = np.mean(img_thre[(x - stride):(x + stride), (y-stride):(y+stride)]) * img_thre[x, y]
return mat[:h,:w]
# 单列黑色或白色(文字像素)比例
BW_RATIO = 0.91
# 图片片段长度阈值
KEEP_SEGL = 2
# 合并阈值
CONCAT_SEGL = 13
# 二值化阈值
BINARY_THRE = 130
def divice_image(ori_img):
# 1、读取图像,并把图像转换为灰度图像并显示
img = cv2.imread(ori_img) # 读取图片
# 2、resize
img = cv2.resize(img,(img.shape[1], 32))
# cv2.imshow("img",img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 3、将灰度图像二值化,设定阈值是130
img_thre = img_gray
cv2.threshold(img_gray, BINARY_THRE, 255, cv2.THRESH_BINARY_INV, img_thre)
# 加上该部分可以增强文字区域的像素积累。
mat = gather(img_thre, stride = 4)
img_thre = img_thre + mat
img_thre[img_thre > 0] = 255
# 4、分割字符
white = np.count_nonzero(img_thre==255,axis=0)
black = np.count_nonzero(img_thre==0,axis=0)
white_max = max(white)
black_max = max(black)
# False表示白底黑字;True表示黑底白字
arg = True if black_max > white_max else False
# 分割图像
def _find_end(start_):
end_ = start_ + 1
for m in range(start_ + 1, width - 1):
if (black[m] if arg else white[m]) > (BW_RATIO * black_max if arg else BW_RATIO * white_max):
end_ = m
break
return end_
def _get_len(se):
return se[1] - se[0]
def _seg_start_end_pos():
n, start, end= 1, 1, 2
start_end = []
while n < width - 2:
n += 1
if (white[n] if arg else black[n]) > ((1 - BW_RATIO) * white_max if arg else (1 - BW_RATIO) * black_max):
# 上面这些判断用来辨别是白底黑字还是黑底白字
start = n
end = _find_end(start)
n = end
if end - start > KEEP_SEGL:
start_end.append((start, end))
tmp = [start_end[0]]
for item in start_end[1:]:
if _get_len(tmp[-1]) < CONCAT_SEGL:
tmp[-1] = (tmp[-1][0], item[1])
else:
tmp.append(item)
return tmp
start_end = _seg_start_end_pos()
return [img[1:height, s-2: e+2] for (s,e) in start_end]