ABCNet Bezier Align模块编译
在做场景文字识别时,经常需要对任意形状的文本进行检测识别,这一直以来是场景文字识别的一个难点,也不断有新的算法出现来解决这个问题。近期的如ABCNet就是众多方法之一。它主要采用参数化的贝塞尔曲线来自适应地处理任意形状的文本,并增加了一层 BezierAlign 来提取任意形状文本实例的精确卷积特征。其算法框架如下图所示:
详细介绍可以参考网上相关教程,其中三次贝塞尔曲线拟合图如下所示:
当我们已经有一个算法模型能够检测出文本框 bbox 和详细的 8 个点的时候,就可以考虑将该算法的 BezierAlign 移植过来进行对齐矫正。ABCNet 的 csrc 目录下包含了该方法的实现,查看源码发现该工程是使用 PyTorch 中已有的算子,采用 C++、CUDA 来实现这部分功能的。
源码整理
我们将这部分单独拿出来编译,去掉不需要的部分,整个工程目录如下:
|
|-- csrc
| |-- BezierAlign.h
| |-- cuda
| | |-- BezierAlign_cuda.cu
| | `-- vision.h
| `-- vision.cpp
|-- setup.py
其中 cuda/vision.h 内容如下:
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>

// Forward-pass declaration (implemented in cuda/BezierAlign_cuda.cu).
// Pools features of `input` for each Bezier region row in `rois` into a
// fixed pooled_height x pooled_width grid; `spatial_scale` maps image
// coordinates onto the feature map and `sampling_ratio` controls the number
// of sample points per output bin.
// NOTE(review): semantics inferred from parameter names / the ABCNet
// BezierAlign design — verify against the .cu implementation.
at::Tensor BezierAlign_forward_cuda(const at::Tensor& input,
                                    const at::Tensor& rois,
                                    const float spatial_scale,
                                    const int pooled_height,
                                    const int pooled_width,
                                    const int sampling_ratio);

// Backward-pass declaration: produces the gradient w.r.t. the forward input,
// whose shape is described by (batch_size, channels, height, width).
at::Tensor BezierAlign_backward_cuda(const at::Tensor& grad,
                                     const at::Tensor& rois,
                                     const float spatial_scale,
                                     const int pooled_height,
                                     const int pooled_width,
                                     const int batch_size,
                                     const int channels,
                                     const int height,
                                     const int width,
                                     const int sampling_ratio);
vision.cpp 如下:
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include "BezierAlign.h"

// Python bindings: exposes the two ops as `bezier.bezier_align_forward` /
// `bezier.bezier_align_backward` (module name comes from setup.py).
// NOTE(review): this binds BezierAlign_forward / BezierAlign_backward
// WITHOUT the `_cuda` suffix — the header shown above only declares the
// *_cuda variants, so BezierAlign.h must also declare dispatch wrappers
// with these exact names or the extension will fail to link; verify.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("bezier_align_forward", &BezierAlign_forward, "BezierAlign_forward");
  m.def("bezier_align_backward", &BezierAlign_backward, "BezierAlign_backward");
}
其中setup.py如下所示:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#!/usr/bin/env python
import glob
import os
import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
# Declared runtime requirements.
# NOTE(review): this list is never passed to setup() below (no
# install_requires), so it is informational only — confirm intent.
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "csrc")
sources = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
sources = [os.path.join(extensions_dir, s) for s in sources]
print(f"sources sources:{sources}")
include_dirs = [extensions_dir]
ext_modules = [
extension(
"bezier",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
# Package metadata; the compiled sources come from get_extensions() above.
# Build with: python3 setup.py build_ext
setup(
    name="bazierAlign",
    version="1.0.0",
    author="xxxxx",
    url="url url",
    description="bazierAlign",
    ext_modules=get_extensions(),
    cmdclass={
        # BuildExtension injects the torch-specific compiler/linker flags.
        "build_ext": torch.utils.cpp_extension.BuildExtension
    },
)
终端执行
python3 setup.py build_ext
验证
import torch   # torch must be imported first so the extension's symbols resolve
import bezier  # the extension module compiled above
print(dir(bezier))
# Expected output:
# ['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__','bezier_align_backward', 'bezier_align_forward']
调用封装部分 bazier_align.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import bezier as B
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from apex import amp
class _BezierAlign(Function):
    """Autograd bridge to the compiled ``bezier`` extension ops."""

    @staticmethod
    def forward(ctx, input, bezier, output_size, spatial_scale, sampling_ratio):
        # Stash everything backward() needs to reconstruct the call.
        ctx.save_for_backward(bezier)
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        return B.bezier_align_forward(
            input,
            bezier,
            spatial_scale,
            output_size[0],
            output_size[1],
            sampling_ratio,
        )

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (beziers,) = ctx.saved_tensors
        pooled_h, pooled_w = ctx.output_size
        batch, channels, height, width = ctx.input_shape
        grad_input = B.bezier_align_backward(
            grad_output,
            beziers,
            ctx.spatial_scale,
            pooled_h,
            pooled_w,
            batch,
            channels,
            height,
            width,
            ctx.sampling_ratio,
        )
        # Gradients flow only to `input`; the remaining args are non-tensors.
        return grad_input, None, None, None, None


bezier_align = _BezierAlign.apply
class BezierAlign(nn.Module):
    """Module wrapper around :func:`bezier_align`.

    Args:
        output_size: (height, width) of the rectified output patch.
        spatial_scale: factor mapping input coordinates to feature coordinates.
        sampling_ratio: number of sampling points per output bin.
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio):
        super(BezierAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    @amp.float_function
    def forward(self, input, beziers):
        return bezier_align(
            input,
            beziers,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
        )

    def __repr__(self):
        attrs = (
            "output_size=" + str(self.output_size)
            + ", spatial_scale=" + str(self.spatial_scale)
            + ", sampling_ratio=" + str(self.sampling_ratio)
        )
        return self.__class__.__name__ + "(" + attrs + ")"
测试
from PIL import Image, ImageOps
import numpy as np
import json
import cv2
import copy
import torch
from torch import nn
from bazier_align import BezierAlign
# Color palette of 3-tuples for drawing boxes/curves.
# NOTE(review): unused in this script; tuples look like OpenCV BGR — confirm.
colors = [
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 0, 255),
    (0, 255, 0),
    (0, 255, 255),
    (0, 255, 255),
    (0, 255, 0),
]
def cat(tensors, dim=0):
    """Concatenate tensors along ``dim``.

    Equivalent to ``torch.cat`` except that a single-element list returns
    that element unchanged, avoiding an unnecessary copy.
    """
    assert isinstance(tensors, (list, tuple))
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
class MyAlign(nn.Module):
    """Test-harness module: masks the input and rectifies it with BezierAlign."""

    def __init__(self, input_size, output_size, scale):
        super(MyAlign, self).__init__()
        # sampling_ratio fixed to 1 for this demo.
        self.bezier_align = BezierAlign(output_size, scale, 1)
        self.masks = nn.Parameter(torch.ones(input_size, dtype=torch.float32))

    def forward(self, input, rois):
        masked = input * self.masks  # apply (all-ones) mask
        return self.bezier_align(masked, self.convert_to_roi_format(rois))

    def convert_to_roi_format(self, beziers):
        """Flatten per-image control points into rows of [batch_idx, coords...]."""
        concat_boxes = cat(list(beziers), dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        # One column of batch indices, repeated per box of each image.
        ids = cat(
            [
                torch.full((len(b), 1), i, dtype=dtype, device=device)
                for i, b in enumerate(beziers)
            ],
            dim=0,
        )
        return torch.cat([ids, concat_boxes], dim=1)
def get_size(image_size, w, h):
    """Return the factor by which a (w, h) image must be shrunk to fit
    inside ``image_size`` (given as (H, W)); returns 1 when it already fits."""
    down_scale = max(w / image_size[1], h / image_size[0])
    return down_scale if down_scale > 1 else 1
def test(scale=1, image_size=(1024, 1024), output_size=(64, 256)):
    """Run one image with one set of Bezier control points through MyAlign
    and save the rectified text patch to disk.

    Args:
        scale: downsampling factor of the feature map w.r.t. the padded image.
        image_size: (H, W) canvas the image is padded into.
        output_size: (H, W) of the rectified output patch.
    """
    # Resolution actually fed to the align module (canvas divided by scale).
    input_size = (image_size[0] // scale, image_size[1] // scale)
    print(f"input_size:{input_size} output_size:{output_size} scale:{scale}")
    my_align = MyAlign(input_size, output_size, 1 / scale).cuda()
    beziers = [[]]    # one inner list of control-point arrays per image
    im_arrs = []      # HWC uint8 arrays, one per image
    down_scales = []  # per-image shrink factor (points must be rescaled too)
    imgfile = 'cutting3_00018.jpg'
    # imgfile = '3.jpg'
    im = Image.open(imgfile)
    print(f"read image size:{im.size}")
    w, h = im.size
    # Factor needed to fit the image inside image_size (1 if it already fits).
    down_scale = get_size(image_size, w, h)
    down_scales.append(down_scale)
    if down_scale > 1:
        # NOTE(review): Image.ANTIALIAS was removed in Pillow 10; newer
        # Pillow requires Image.LANCZOS (historically the same filter).
        im = im.resize((int(w / down_scale), int(h / down_scale)), Image.ANTIALIAS)
        w, h = im.size
    # Pad right/bottom so the canvas is exactly image_size
    # (PIL's ImageOps.expand takes (left, top, right, bottom)).
    padding = (0, 0, image_size[1] - w, image_size[0] - h)
    im = ImageOps.expand(im, padding)
    # Resize the padded canvas to the feature-map resolution (PIL wants (W, H)).
    im = im.resize((input_size[1], input_size[0]), Image.ANTIALIAS)
    im_arrs.append(np.array(im))
    """
    object predict 8 points order:
    p1----p2----p3----p4
    |                  |
    |                  |
    p8----p7----p6-----p5
    bazier align order
    p1----p2----p3----p4
    |                  |
    |                  |
    p5----p6----p7-----p8
    """
    # p1, p2, p3, p4, p5, p6, p7, p8
    cps = [
        712.0165405273438,
        255.10353088378906,
        789.5339965820312,
        334.0591125488281,
        834.0347900390625,
        454.64581298828125,
        834.0347900390625,
        575.2326049804688,
        782.3565063476562,
        575.2326049804688,
        780.9209594726562,
        466.1302795410156,
        740.7266845703125,
        364.2057800292969,
        677.5643310546875,
        301.04132080078125,
    ]
    # Reorder: swap each coordinate pair and reverse the bottom edge so the
    # points follow the bezier-align ordering sketched above.
    # NOTE(review): the pairwise [1, 0] swap implies the source order is
    # (y, x) — confirm against the upstream detector's output convention.
    cps = np.array(cps)[[1, 0, 3, 2, 5, 4, 7, 6, 15, 14, 13, 12, 11, 10, 9, 8]]
    beziers[0].append(cps)
    # Move control points to GPU floats and rescale by the per-image shrink.
    beziers = [torch.from_numpy(np.stack(b)).cuda().float() for b in beziers]
    beziers = [b / d for b, d in zip(beziers, down_scales)]
    print(f"beziers:{beziers}")
    im_arrs = np.stack(im_arrs)
    # HWC uint8 -> NCHW float on GPU.
    x = torch.from_numpy(im_arrs).permute(0, 3, 1, 2).cuda().float()
    x = my_align(x, beziers)
    # Save each rectified region back out as an RGB image.
    for i, roi in enumerate(x):
        roi = roi.cpu().detach().numpy().transpose(1, 2, 0).astype(np.uint8)
        im = Image.fromarray(roi, "RGB")
        im.save(f'aligned_{imgfile.split(".")[0]}.png')
# Script entry point: run the single-image demo.
if __name__ == '__main__':
    test()