Introduction
ABCNet: a real-time, end-to-end network for detecting and recognizing natural-scene text, based on adaptive Bezier curves
Paper: https://arxiv.org/abs/2002.10200
Official open-source code: https://github.com/aim-uofa/AdelaiDet
ABCNet: Real-Time Scene Text Spotting With Adaptive Bezier-Curve Network
Authors | Yuliang Liu, Hao Chen, Chunhua Shen, Tong He, Lianwen Jin, Liangwei Wang
Affiliations | South China University of Technology; The University of Adelaide
代码 | https://github.com/Yuliang-Liu/bezier_curve_text_spotting
Remarks | CVPR 2020 Oral
Explainer (Chinese) | https://zhuanlan.zhihu.com/p/146276834
The paper was accepted to CVPR 2020. Its contributions: (1) using Bezier curves to fit arbitrarily shaped text, and (2) proposing BezierAlign to extract text-instance features more accurately.
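For context: a cubic Bezier curve is determined by four control points, and ABCNet describes each text instance with two such curves (top and bottom edges, 8 control points in total). Below is a minimal NumPy sketch of sampling one curve; the helper name is made up here, purely to fix the parameterization used by the conversion scripts later on.
import numpy as np
from scipy.special import comb

def cubic_bezier(ctrl, ts):
    """Sample a cubic Bezier curve.
    ctrl: (4, 2) control points; ts: (n,) parameters in [0, 1]."""
    # Bernstein basis: B_{k,3}(t) = C(3, k) * t^k * (1 - t)^(3 - k)
    basis = np.stack([comb(3, k) * ts ** k * (1 - ts) ** (3 - k) for k in range(4)], axis=1)
    return basis @ ctrl  # (n, 2) points on the curve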
For the theory, the Zhihu explainer above covers the introduction well, so it is not repeated here. Let's get to the practical part: environment setup + testing + training.
Environment
# env
torch 1.4.0
torchvision 0.5.0
py362 cuda10.1
# 1. First install Detectron2
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
# 2. Then build AdelaiDet
cd AdelaiDet
python setup.py build develop
As shown above, first install the Detectron2 module, following the official guide: INSTALL.md.
The second step installs AdelaiDet itself; just compile with the commands above.
The environment above was configured in an anaconda virtual environment with torch 1.4.0 + torchvision 0.5.0, Python 3.6.2 and CUDA 10.1 (the machine itself actually has CUDA 10.2 installed).
One warning: the torch and torchvision versions must be a matching pair (1.4.0 goes with 0.5.0), otherwise the build fails midway with all sorts of baffling errors. Learned the hard way.
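Before compiling, it helps to confirm the interpreter actually sees a matching pair; a quick sanity check run inside the conda environment:
import torch
import torchvision

print("torch:", torch.__version__)                # expect 1.4.0
print("torchvision:", torchvision.__version__)    # expect 0.5.0 (the pair matching torch 1.4.0)
print("CUDA (torch build):", torch.version.cuda)  # expect 10.1
print("GPU visible:", torch.cuda.is_available())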
Finally, an overview of the full environment:
---------------------- ---------------------------------------------------------------------------------------------------------
sys.platform linux
Python 3.6.2 |Continuum Analytics, Inc.| (default, Jul 20 2017, 13:51:32) [GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
numpy 1.19.1
detectron2 0.2.1 @/home/gavin/MyProj/tempwork/ocr/AdelaiDet/detectron2/detectron2
Compiler GCC 5.4
CUDA compiler CUDA 10.2
detectron2 arch flags 6.1
DETECTRON2_ENV_MODULE <not set>
PyTorch 1.4.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torch
PyTorch debug build False
GPU available True
GPU 0 GeForce GTX 1080 Ti (arch=6.1)
CUDA_HOME /usr/local/cuda-10.2
Pillow 4.2.1
torchvision 0.5.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torchvision
torchvision arch flags 3.5, 5.0, 6.0, 7.0, 7.5
fvcore 0.1.1.post20200716
cv2 4.4.0
---------------------- ---------------------------------------------------------------------------------------------------------
PyTorch built with:
- GCC 7.3
- Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- NNPACK is enabled
- CUDA Runtime 10.1
- NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
- CuDNN 7.6.3
- Magma 2.5.1
- Build settings: BLAS=MKL, BUILD_NAMEDTENSOR=OFF, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Wno-stringop-overflow, DISABLE_NUMA=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,
Testing
Assuming the environment is set up, run the following test:
# test
python demo/demo.py \
--config-file configs/BAText/TotalText/attn_R_50.yaml \
--input datasets/totaltext/test_images/ \
--opts MODEL.WEIGHTS tt_attn_R_50.pth
e.g.:
python demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input /media/gavin/home/gavin/DataSet/ocr/test/bdz --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
Test results with the TotalText-trained model are shown below:
The detections it does make are fairly accurate, but quite a few instances are missed, which may be down to the data: the test images here are from ICDAR2015.
The weights used for testing are the officially trained ones.
With the self-trained CTW1500 model below, far fewer instances are missed; results are shown at the end.
Training
Now for the most critical, and most troublesome, part.
Dataset preparation
Method 1: annotate your own data directly with the windows_label_tool. Then only a single conversion is needed: from the windows_label_tool format to the JSON format that ABCNet trains on.
See the reference here.
# Making a custom ABCNet dataset: converting ICDAR15 to the ABCNet annotation format. Reference: https://github.com/Yuliang-Liu/Curve-Text-Detector/tree/master/data
# 1. Convert labelme annotations to the windows_label_tool format, shown below. The first line gives the number of annotations; each following line is one annotation with 28/2 = 14 point coordinates (7 along the top edge, 7 along the bottom), followed by the quoted text content.
4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,"DOUGLASTON"
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,"L164"
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,"F.D.N.Y."
# 2. convert_ann_to_json: convert the windows_label_tool annotations to the JSON annotations used for ABCNet training
python convert_ann_to_json.py \
--ann-dir /path/to/gt \
--image-dir /path/to/image \
--dst-json-path train.json
e.g.:
python convert_ann_to_json.py --ann-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500_e2e_annos/ctw1500_e2e_test \
--image-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500/test/text_image \
--dst-json-path ./abc_json/test.json
JSON conversion script:
# -*- coding: utf-8 -*-
"""
@File : convert_ann_to_json.py
@Time : 2020-8-17 16:13
@Author : yizuotian
@Description : Convert windows_label_tool annotations to the JSON format used for ABCNet training
"""
import argparse
import json
import os
import sys
import cv2
import bezier_utils
import numpy as np
def gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):
"""
    Generate COCO-format json annotations from the abcnet gt annotations.
    :param abc_gt_dir: directory of annotation files produced by the windows_label_tool
    :param abc_json_path: output path for the json annotations used by ABCNet training
    :param image_dir: directory of the corresponding images
    :param classes_path: path to the class-list file
:return:
"""
# Desktop Latin_embed.
cV2 = [' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4',
'5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
'`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
dataset = {
'licenses': [],
'info': {},
'categories': [],
'images': [],
'annotations': []
}
with open(classes_path) as f:
classes = f.read().strip().split()
for i, cls in enumerate(classes, 1):
dataset['categories'].append({
'id': i,
'name': cls,
'supercategory': 'beverage',
'keypoints': ['mean',
'xmin',
'x2',
'x3',
'xmax',
'ymin',
'y2',
'y3',
'ymax',
'cross'] # only for BDN
})
def get_category_id(cls):
for category in dataset['categories']:
if category['name'] == cls:
return category['id']
    # iterate over the abcnet txt annotations
indexes = sorted([f.split('.')[0]
for f in os.listdir(abc_gt_dir)])
print(indexes)
    j = 1  # running id for annotation boxes
for index in indexes:
# if int(index) >3: continue
# print('Processing: ' + index)
im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))
im_height, im_width = im.shape[:2]
dataset['images'].append({
'coco_url': '',
'date_captured': '',
'file_name': index + '.jpg',
'flickr_url': '',
'id': int(index.split('_')[-1]), # img_1
'license': 0,
'width': im_width,
'height': im_height
})
anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))
with open(anno_file) as f:
lines = [line for line in f.readlines() if line.strip()]
            # no usable annotations, skip
if len(lines) <= 1:
continue
for i, line in enumerate(lines[1:]):
elements = line.strip().split(',')
polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32) # [14,(x,y)]
control_points = bezier_utils.polygon_to_bezier_pts(polygon, im) # [8,(x,y)]
ct = elements[-1].replace('"', '').strip()
cls = 'text'
# segs = [float(kkpart) for kkpart in parts[:16]]
segs = [float(kkpart) for kkpart in control_points.flatten()]
xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]
                # filter boxes that fall outside the image
if max(xt) > im_width or max(yt) > im_height:
print('The annotation bounding box is outside of the image:{}'.format(index))
print("max x:{},max y:{},w:{},h:{}".format(max(xt), max(yt), im_width, im_height))
continue
xmin = min([xt[0], xt[3], xt[4], xt[7]])
ymin = min([yt[0], yt[3], yt[4], yt[7]])
xmax = max([xt[0], xt[3], xt[4], xt[7]])
ymax = max([yt[0], yt[3], yt[4], yt[7]])
width = max(0, xmax - xmin + 1)
height = max(0, ymax - ymin + 1)
if width == 0 or height == 0:
continue
max_len = 100
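                # 'rec' encoding: index of each char in cV2; len(cV2)=95 marks an out-of-vocabulary char, len(cV2)+1=96 pads up to max_len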
recs = [len(cV2) + 1 for ir in range(max_len)]
ct = str(ct)
# print('rec', ct)
for ix, ict in enumerate(ct):
if ix >= max_len:
continue
if ict in cV2:
recs[ix] = cV2.index(ict)
else:
recs[ix] = len(cV2)
dataset['annotations'].append({
'area': width * height,
'bbox': [xmin, ymin, width, height],
'category_id': get_category_id(cls),
'id': j,
'image_id': int(index.split('_')[-1]), # img_1
'iscrowd': 0,
'bezier_pts': segs,
'rec': recs
})
j += 1
    # write the json file
folder = os.path.dirname(abc_json_path)
if not os.path.exists(folder):
os.makedirs(folder)
with open(abc_json_path, 'w') as f:
json.dump(dataset, f)
def main(args):
gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)
if __name__ == '__main__':
"""
Usage: python convert_ann_to_json.py \
--ann-dir /path/to/gt \
--image-dir /path/to/image \
--dst-json-path train.json
"""
parse = argparse.ArgumentParser()
parse.add_argument("--ann-dir", type=str, default=None)
parse.add_argument("--image-dir", type=str, default=None)
parse.add_argument("--dst-json-path", type=str, default=None)
parse.add_argument("--classes-path", type=str, default='./classes.txt')
arguments = parse.parse_args() # sys.argv[1:]
main(arguments)
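Note that the script imports bezier_utils, which is not part of AdelaiDet: polygon_to_bezier_pts is the helper that turns the 14-point polygon into 8 Bezier control points. A minimal sketch of what such a helper could look like is below, assuming a least-squares fit with chord-length parameterization (the same idea as bezier_fit in Method 2 below); the actual helper may differ.
# bezier_utils.py -- hypothetical sketch of the helper imported above
import numpy as np
from scipy.special import comb

def _bezier_design_matrix(ts):
    # Bernstein basis for a cubic Bezier: one row [B0(t), B1(t), B2(t), B3(t)] per t
    return np.array([[comb(3, k) * t ** k * (1 - t) ** (3 - k) for k in range(4)] for t in ts])

def _fit_cubic(points):
    # least-squares fit of 4 control points to (m, 2) polygon points,
    # parameterized by normalized cumulative chord length
    d = np.linalg.norm(np.diff(points, axis=0), axis=1)
    ts = np.concatenate(([0.0], np.cumsum(d) / d.sum()))
    return np.linalg.pinv(_bezier_design_matrix(ts)).dot(points)  # (4, 2)

def polygon_to_bezier_pts(polygon, im=None):
    # polygon: [14, (x, y)] -- 7 points along the top edge, then 7 along the bottom
    top, bottom = polygon[:7], polygon[7:]
    return np.vstack([_fit_cubic(top), _fit_cubic(bottom)])  # [8, (x, y)]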
Method 2:
Convert the labelme annotations to the windows_label_tool format first, then run the JSON conversion from Method 1.
Converting labelme annotations to the windows_label_tool format:
1. Convert labelme json annotations to abcnet gt annotations (this step can be skipped if you annotate directly with the windows_label_tool).
# coding=utf-8
# Convert labelme json annotations to abcnet annotations; skip this step if annotating directly with the windows_label_tool
import glob
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.special import comb as n_over_k
from shapely.geometry import *
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from torch import nn
class Bezier(nn.Module):
def __init__(self, ps, ctps):
"""
ps: numpy array of points
"""
super(Bezier, self).__init__()
self.x1 = nn.Parameter(torch.as_tensor(ctps[0], dtype=torch.float64))
self.x2 = nn.Parameter(torch.as_tensor(ctps[2], dtype=torch.float64))
self.y1 = nn.Parameter(torch.as_tensor(ctps[1], dtype=torch.float64))
self.y2 = nn.Parameter(torch.as_tensor(ctps[3], dtype=torch.float64))
self.x0 = ps[0, 0]
self.x3 = ps[-1, 0]
self.y0 = ps[0, 1]
self.y3 = ps[-1, 1]
self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
self.t = torch.as_tensor(np.linspace(0, 1, 81))
def forward(self):
x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
t = self.t
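        # cubic Bezier evaluated with De Casteljau's algorithm (nested linear interpolations in t)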
bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
bezier = torch.stack((bezier_x, bezier_y), dim=1)
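        # fitting loss: for each intermediate polygon point, distance to the nearest of the 81 sampled curve points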
diffs = bezier.unsqueeze(0) - self.inner_ps.unsqueeze(1)
sdiffs = diffs ** 2
dists = sdiffs.sum(dim=2).sqrt()
min_dists, min_inds = dists.min(dim=1)
return min_dists.sum()
def control_points(self):
return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3
def control_points_f(self):
return self.x0, self.x1.item(), self.x2.item(), self.x3, self.y0, self.y1.item(), self.y2.item(), self.y3
def train(x, y, ctps, lr):
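    # NOTE: despite the name, this trimmed-down version does no gradient refinement;
    # the lr argument is unused and the least-squares control points are returned as-is.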
x, y = np.array(x), np.array(y)
ps = np.vstack((x, y)).transpose()
bezier = Bezier(ps, ctps)
return bezier.control_points_f()
def draw(ps, control_points, t):
x = ps[:, 0]
y = ps[:, 1]
x0, x1, x2, x3, y0, y1, y2, y3 = control_points
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x,y,color='m',linestyle='',marker='.')
bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
plt.plot(bezier_x,bezier_y, 'g-')
plt.draw()
    plt.pause(1)
    input("<Hit Enter To Close>")  # Python 3: raw_input() no longer exists
plt.close(fig)
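# Bernstein basis: Mtk(n, t, k) = C(n, k) * t^k * (1-t)^(n-k)
# BezierCoeff(ts) builds the cubic-Bezier design matrix, one row [B0(t)..B3(t)] per parameter t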
Mtk = lambda n, t, k: t**k * (1-t)**(n-k) * n_over_k(n,k)
BezierCoeff = lambda ts: [[Mtk(3,t,k) for k in range(4)] for t in ts]
def bezier_fit(x, y):
dy = y[1:] - y[:-1]
dx = x[1:] - x[:-1]
dt = (dx ** 2 + dy ** 2)**0.5
t = dt/dt.sum()
t = np.hstack(([0], t))
t = t.cumsum()
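    # t is now a normalized cumulative chord-length parameterization over [0, 1]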
data = np.column_stack((x, y))
Pseudoinverse = np.linalg.pinv(BezierCoeff(t)) # (9,4) -> (4,9)
control_points = Pseudoinverse.dot(data) # (4,9)*(9,2) -> (4,2)
medi_ctp = control_points[1:-1,:].flatten().tolist()
return medi_ctp
def bezier_fitv2(x, y):
xc01 = (2*x[0] + x[-1])/3.0
yc01 = (2*y[0] + y[-1])/3.0
xc02 = (x[0] + 2* x[-1])/3.0
yc02 = (y[0] + 2* y[-1])/3.0
control_points = [xc01,yc01,xc02,yc02]
return control_points
def is_close_to_line(xs, ys, thres):
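    # NOTE: superseded by is_close_to_linev2 below; only v2 is called in the main loop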
regression_model = LinearRegression()
# Fit the data(train the model)
regression_model.fit(xs.reshape(-1,1), ys.reshape(-1,1))
# Predict
y_predicted = regression_model.predict(xs.reshape(-1,1))
# model evaluation
rmse = mean_squared_error(ys.reshape(-1,1)**2, y_predicted**2)
rmse = rmse/(ys.reshape(-1,1)**2- y_predicted**2).max()**2
if rmse > thres:
return 0.0
else:
return 2.0
def is_close_to_linev2(xs, ys, size, thres = 0.05):
pts = []
nor_pixel = int(size**0.5)
for i in range(len(xs)):
pts.append(Point([xs[i], ys[i]]))
    # iterate over consecutive pairs of points
slopes = [(second.y-first.y)/(second.x-first.x) if not (second.x-first.x) == 0.0 else math.inf*np.sign((second.y-first.y)) for first, second in zip(pts, pts[1:])]
st_slope = (ys[-1] - ys[0])/(xs[-1] - xs[0])
max_dis = ((ys[-1] - ys[0])**2 +(xs[-1] - xs[0])**2)**(0.5)
    diffs = np.abs(np.array(slopes) - st_slope)  # convert to ndarray first; a plain list minus a float would raise TypeError
score = diffs.sum() * max_dis/nor_pixel
if score < thres:
return 0.0
else:
return 3.0
labels = glob.glob("dataset/json/*.json")
labels.sort()
if not os.path.isdir('abcnet_gen_labels'):
os.mkdir('abcnet_gen_labels')
for il, label in enumerate(labels):
print('Processing: '+label)
imgdir = label.replace('json/', 'image/').replace('.json', '.jpg')
outgt = open(label.replace('dataset/json/', 'abcnet_gen_labels/').replace('.json', '.txt'), 'w')
data = []
cts = []
with open(label,"r") as f:
jdata = json.loads(f.read())
boxes = jdata["shapes"]
for il ,box in enumerate(boxes):
line,ct = box["points"],box["label"]
pts =[]
[pts.extend(p) for p in line]
if len(line) == 4:
pts = line[0] + [(line[0][0]+line[1][0])//2, (line[0][1]+line[1][1])//2] + line[1] + line[2] +[(line[2][0]+line[3][0])/2, (line[2][1]+line[3][1])/2]+ line[3]
if len(line) == 6:
if abs(line[0][0] - line[1][0]) > abs(line[1][0] - line[2][0]):
pts= line[0] + [(line[0][0]+line[1][0])//2, (line[0][1]+line[1][1])//2] + line[1] + line[2]
pts += line[3] + [(line[3][0]+line[4][0])//2, (line[3][1]+line[4][1])//2] + line[4] + line[5]
else:
pts = line[0] + line[1] + [(line[1][0]+line[2][0])//2, (line[1][1]+line[2][1])//2] + line[2]
pts += line[3] + line[4] + [(line[4][0]+line[5][0])//2, (line[4][1]+line[5][1])//2] + line[5]
data.append(np.array([float(x) for x in pts]))
cts.append(ct)
############## top
img = plt.imread(imgdir)
for iid, ddata in enumerate(data):
lh = len(data[iid])
assert(lh % 4 ==0)
lhc2 = int(lh/2)
lhc4 = int(lh/4)
xcors = [data[iid][i] for i in range(0, len(data[iid]),2)]
ycors = [data[iid][i+1] for i in range(0, len(data[iid]),2)]
curve_data_top = data[iid][0:lhc2].reshape(lhc4, 2)
curve_data_bottom = data[iid][lhc2:].reshape(lhc4, 2)
left_vertex_x = [curve_data_top[0,0], curve_data_bottom[lhc4-1,0]]
left_vertex_y = [curve_data_top[0,1], curve_data_bottom[lhc4-1,1]]
right_vertex_x = [curve_data_top[lhc4-1,0], curve_data_bottom[0,0]]
right_vertex_y = [curve_data_top[lhc4-1,1], curve_data_bottom[0,1]]
x_data = curve_data_top[:, 0]
y_data = curve_data_top[:, 1]
init_control_points = bezier_fit(x_data, y_data)
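        # heuristic 'learning rate': 0.0 if the edge points are nearly collinear, else 3.0 (unused by the trimmed train() above)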
learning_rate = is_close_to_linev2(x_data, y_data, img.size)
x0, x1, x2, x3, y0, y1, y2, y3 = train(x_data, y_data, init_control_points, learning_rate)
control_points = np.array([
[x0,y0],\
[x1,y1],\
[x2,y2],\
[x3,y3]
])
x_data_b = curve_data_bottom[:, 0]
y_data_b = curve_data_bottom[:, 1]
init_control_points_b = bezier_fit(x_data_b, y_data_b)
learning_rate = is_close_to_linev2(x_data_b, y_data_b, img.size)
x0_b, x1_b, x2_b, x3_b, y0_b, y1_b, y2_b, y3_b = train(x_data_b, y_data_b, init_control_points_b, learning_rate)
control_points_b = np.array([
[x0_b,y0_b],\
[x1_b,y1_b],\
[x2_b,y2_b],\
[x3_b,y3_b]
])
t_plot = np.linspace(0, 1, 81)
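        # sample 81 points along each fitted cubic for visualization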
Bezier_top = np.array(BezierCoeff(t_plot)).dot(control_points)
Bezier_bottom = np.array(BezierCoeff(t_plot)).dot(control_points_b)
plt.plot(Bezier_top[:,0], Bezier_top[:,1], 'g-', label='fit', linewidth=1)
plt.plot(Bezier_bottom[:,0],Bezier_bottom[:,1],'g-', label='fit', linewidth=1)
plt.plot(control_points[:,0],control_points[:,1], 'r.:', fillstyle='none', linewidth=1)
plt.plot(control_points_b[:,0],control_points_b[:,1], 'r.:', fillstyle='none', linewidth=1)
plt.plot(left_vertex_x, left_vertex_y, 'g-', linewidth=1)
plt.plot(right_vertex_x, right_vertex_y, 'g-', linewidth=1)
outstr = '{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}||||{}\n'.format(round(x0,2),round(y0,2),\
round(x1, 2), round(y1, 2),\
round(x2, 2), round(y2, 2),\
round(x3, 2), round(y3, 2),\
round(x0_b, 2), round(y0_b, 2),\
round(x1_b, 2), round(y1_b, 2),\
round(x2_b, 2), round(y2_b, 2),\
round(x3_b, 2), round(y3_b, 2),\
cts[iid])
outgt.writelines(outstr)
outgt.close()
plt.imshow(img)
plt.axis('off')
if not os.path.isdir('abcnet_vis'):
os.mkdir('abcnet_vis')
plt.savefig('abcnet_vis/'+os.path.basename(imgdir), bbox_inches='tight',dpi=400)
plt.clf()
2. Convert the generated abcnet annotations to ABCNet json with the Method 1 script above.
Modify the configuration
- Modify the relevant config files for training
- Put the prepared data directory under "AdelaiDet/datasets"
- Modify the _PREDEFINED_SPLITS_TEXT value in "adet/data/builtin.py" to register your train/test data. Note that paths are resolved relative to datasets/ by default, so each entry starts from the next directory level down:
_PREDEFINED_SPLITS_TEXT = {
    "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
    "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
    ...
    "abcnet_train": ("data/train", "data/annotations/train.json"),
    "abcnet_test": ("data/test", "data/annotations/test.json"),
}
- Then point the config file you train with at these datasets. Taking configs/BAText/CTW1500/Base-CTW1500.yaml as an example:
DATASETS:  # detail cfg: AdelaiDet/adet/data/builtin.py
  TRAIN: ("abcnet_train",)
  TEST: ("abcnet_test",)
The training commands are as follows:
# train custom
#1. Pretraining with synthetic data:
OMP_NUM_THREADS=1 python tools/train_net.py \
--config-file configs/BAText/Pretrain/attn_R_50.yaml \
--num-gpus 4 \
OUTPUT_DIR text_pretraining/attn_R_50
#2. Finetuning
OMP_NUM_THREADS=1 python tools/train_net.py \
--config-file configs/BAText/CTW1500/attn_R_50.yaml \
--num-gpus 4 \
MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
e.g.:
# 1.
OMP_NUM_THREADS=1 python tools/train_net.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1
# 2. Finetuning on CTW1500:
OMP_NUM_THREADS=1 python tools/train_net.py \
--config-file configs/BAText/CTW1500/attn_R_50.yaml \
--num-gpus 1 \
MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
e.g.:
OMP_NUM_THREADS=1 python tools/train_net.py --config-file \
configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1 MODEL.WEIGHTS \
output/batext/ctw1500/attn_R_50/model_final.pth
# eval:
python tools/train_net.py \
--config-file configs/BAText/CTW1500/attn_R_50.yaml \
--eval-only \
MODEL.WEIGHTS ctw1500_attn_R_50.pth
Test:
python demo/demo.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS output/batext/ctw1500/attn_R_50/model_0119999.pth
python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
The above covers training on English text and digits; training on Chinese text, and the modifications it requires, will be covered later.
After training on CTW1500, the test results are as follows:
Looking at the last image, nothing is missed this time, which is better than the TotalText-trained results, and that is with only 120,000 + 120,000 training steps.