Practical Notes Series 5

Compiling the ABCNet BezierAlign Module

In scene text recognition we often need to detect and recognize arbitrarily shaped text, which has long been a difficult problem, and new algorithms keep appearing to address it. ABCNet is one recent example. It uses parameterized Bezier curves to adaptively fit arbitrarily shaped text, and adds a BezierAlign layer to extract accurate convolutional features for arbitrarily shaped text instances. Its framework is shown below:
(figure: ABCNet framework)
For a detailed introduction, see the related tutorials online. Cubic Bezier curve fitting is illustrated below:
(figure: cubic Bezier curve fitting of a text boundary)
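As a quick refresher, a cubic Bezier curve is the degree-3 Bernstein form B(t) = Σ_{i=0}^{3} C(3,i)(1-t)^(3-i) t^i P_i over four control points. A minimal pure-Python sketch (an illustration, not taken from the ABCNet code):

```python
from math import comb

def cubic_bezier(p, t):
    """Evaluate a cubic Bezier curve at parameter t in [0, 1].

    p: four control points, each an (x, y) tuple.
    Returns the interpolated (x, y) point on the curve.
    """
    x = sum(comb(3, i) * (1 - t) ** (3 - i) * t ** i * p[i][0] for i in range(4))
    y = sum(comb(3, i) * (1 - t) ** (3 - i) * t ** i * p[i][1] for i in range(4))
    return (x, y)

# The curve starts at the first control point and ends at the last one.
pts = [(0.0, 0.0), (1.0, 2.0), (3.0, 2.0), (4.0, 0.0)]
print(cubic_bezier(pts, 0.0))  # (0.0, 0.0)
print(cubic_bezier(pts, 1.0))  # (4.0, 0.0)
```

ABCNet fits one such curve to the top boundary of a text instance and one to the bottom, so eight control points describe the whole region.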
Once we have a model that can detect text bounding boxes together with the detailed eight control points, we can port this algorithm's BezierAlign operator for alignment and rectification. The csrc directory of ABCNet contains the implementation; reading the source shows that it is built on existing PyTorch operators and implemented in C++ and CUDA.

Organizing the Source

We pull this part out to compile it on its own, dropping everything we don't need. The project layout is:

|                  
|-- csrc
|   |-- BezierAlign.h
|   |-- cuda
|   |   |-- BezierAlign_cuda.cu
|   |   `-- vision.h
|   `-- vision.cpp
|-- setup.py

BezierAlign.h contains:

// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>

at::Tensor BezierAlign_forward_cuda(const at::Tensor& input,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width,
                                 const int sampling_ratio);

at::Tensor BezierAlign_backward_cuda(const at::Tensor& grad,
                                  const at::Tensor& rois,
                                  const float spatial_scale,
                                  const int pooled_height,
                                  const int pooled_width,
                                  const int batch_size,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int sampling_ratio);

vision.cpp is as follows:

// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include "BezierAlign.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("bezier_align_forward", &BezierAlign_forward_cuda, "BezierAlign_forward");
  m.def("bezier_align_backward", &BezierAlign_backward_cuda, "BezierAlign_backward");
}

setup.py is shown below:

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import glob
import os

import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "csrc")

    sources = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
    
    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]

    # glob already returns absolute paths under extensions_dir, so no extra join is needed
    print(f"sources: {sources}")
    
    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "bezier",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


setup(
    name="bazierAlign",
    version="1.0.0",
    author="xxxxx",
    url="url url",
    description="bazierAlign",
    ext_modules=get_extensions(),
    cmdclass={
        "build_ext": torch.utils.cpp_extension.BuildExtension
    },
)

Run in a terminal:

python3 setup.py build_ext

Verification

import torch
import bezier
print(dir(bezier))
# output:
# ['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'bezier_align_backward', 'bezier_align_forward']

The Python wrapper, bazier_align.py:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import bezier as B

from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from apex import amp

class _BezierAlign(Function):
    @staticmethod
    def forward(ctx, input, bezier, output_size, spatial_scale, sampling_ratio):
        # normalize output_size to an (h, w) pair before using it
        output_size = _pair(output_size)
        ctx.save_for_backward(bezier)
        ctx.output_size = output_size
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        output = B.bezier_align_forward(input, bezier, spatial_scale, output_size[0], output_size[1], sampling_ratio)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        beziers, = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        bs, ch, h, w = ctx.input_shape
        grad_input = B.bezier_align_backward(
            grad_output,
            beziers,
            spatial_scale,
            output_size[0],
            output_size[1],
            bs,
            ch,
            h,
            w,
            sampling_ratio,
        )
        return grad_input, None, None, None, None


bezier_align = _BezierAlign.apply


class BezierAlign(nn.Module):
    def __init__(self, output_size, spatial_scale, sampling_ratio):
        """BezierAlign layer.

        Args:
            output_size: size (h, w) of the pooled output feature map
            spatial_scale: scale from the input image to the feature map
            sampling_ratio: number of sampling points per output bin
        """
        super(BezierAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    @amp.float_function
    def forward(self, input, beziers):
        return bezier_align(input, beziers, self.output_size, self.spatial_scale, self.sampling_ratio)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "output_size=" + str(self.output_size)
        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
        tmpstr += ")"
        return tmpstr
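Under the hood, BezierAlign works much like RoIAlign: for each output bin it interpolates sample locations between the top and bottom Bezier boundaries, then reads the feature map with bilinear interpolation. The bilinear-sampling primitive alone can be sketched in pure Python (an illustration of the idea, not the CUDA kernel):

```python
def bilinear_sample(feat, y, x):
    """Bilinearly interpolate a 2D feature map (list of rows) at a
    fractional location (y, x). This mirrors the per-sample read used
    inside RoIAlign/BezierAlign kernels, without border handling.
    """
    y0, x0 = int(y), int(x)
    y1, x1 = y0 + 1, x0 + 1
    ly, lx = y - y0, x - x0          # fractional offsets
    hy, hx = 1.0 - ly, 1.0 - lx      # complementary weights
    return (feat[y0][x0] * hy * hx + feat[y0][x1] * hy * lx +
            feat[y1][x0] * ly * hx + feat[y1][x1] * ly * lx)

feat = [[0.0, 1.0],
        [2.0, 3.0]]
print(bilinear_sample(feat, 0.5, 0.5))  # 1.5 (average of the four neighbors)
```

The real kernel does this for every sampling point of every output bin, averaging the samples per bin, which is why the output is differentiable with respect to the feature map.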

Testing

from PIL import Image, ImageOps
import numpy as np
import json
import cv2
import copy
import torch
from torch import nn

from bazier_align import BezierAlign

colors = [
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 0, 255),
    (0, 255, 0),
    (0, 255, 255),
    (0, 255, 255),
    (0, 255, 0),
]


def cat(tensors, dim=0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


class MyAlign(nn.Module):
    """
    BezierAlign test module
    """
    def __init__(self, input_size, output_size, scale):
        super(MyAlign, self).__init__()
        self.bezier_align = BezierAlign(output_size, scale, 1)
        self.masks = nn.Parameter(torch.ones(input_size, dtype=torch.float32))

    def forward(self, input, rois):
        # apply mask
        x = input * self.masks
        rois = self.convert_to_roi_format(rois)
        return self.bezier_align(x, rois)

    def convert_to_roi_format(self, beziers):
        concat_boxes = cat([b for b in beziers], dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = cat(
            [torch.full((len(b), 1), i, dtype=dtype, device=device) for i, b in enumerate(beziers)],
            dim=0,
        )
        
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois


def get_size(image_size, w, h):
    w_ratio = w / image_size[1]
    h_ratio = h / image_size[0]
    down_scale = max(w_ratio, h_ratio)
    if down_scale > 1:
        return down_scale
    else:
        return 1


def test(scale=1, image_size=(1024, 1024), output_size=(64, 256)):
    """Run BezierAlign on one image using a hand-labeled set of eight control points."""
    input_size = (image_size[0] // scale, image_size[1] // scale)
    print(f"input_size:{input_size} output_size:{output_size} scale:{scale}")
    my_align = MyAlign(input_size, output_size, 1 / scale).cuda()

    beziers = [[]]
    im_arrs = []
    down_scales = []

    imgfile = 'cutting3_00018.jpg'
    # imgfile = '3.jpg'
    im = Image.open(imgfile)
    print(f"read image size:{im.size}")
    w, h = im.size
    down_scale = get_size(image_size, w, h)
    down_scales.append(down_scale)
    if down_scale > 1:
        im = im.resize((int(w / down_scale), int(h / down_scale)), Image.ANTIALIAS)
        w, h = im.size
    padding = (0, 0, image_size[1] - w, image_size[0] - h)
    im = ImageOps.expand(im, padding)
    im = im.resize((input_size[1], input_size[0]), Image.ANTIALIAS)
    im_arrs.append(np.array(im))
    """
    object predict 8 points order:
        p1----p2----p3----p4
        |                  |
        |                  |
        p8----p7----p6-----p5
    bezier align order
        p1----p2----p3----p4
        |                  |
        |                  |
        p5----p6----p7-----p8
    """
    # p1, p2, p3, p4, p5, p6, p7, p8
    cps = [
        712.0165405273438,
        255.10353088378906,
        789.5339965820312,
        334.0591125488281,
        834.0347900390625,
        454.64581298828125,
        834.0347900390625,
        575.2326049804688,
        782.3565063476562,
        575.2326049804688,
        780.9209594726562,
        466.1302795410156,
        740.7266845703125,
        364.2057800292969,
        677.5643310546875,
        301.04132080078125,
    ]
    cps = np.array(cps)[[1, 0, 3, 2, 5, 4, 7, 6, 15, 14, 13, 12, 11, 10, 9, 8]]
    beziers[0].append(cps)
    beziers = [torch.from_numpy(np.stack(b)).cuda().float() for b in beziers]
    beziers = [b / d for b, d in zip(beziers, down_scales)]
    print(f"beziers:{beziers}")

    im_arrs = np.stack(im_arrs)
    x = torch.from_numpy(im_arrs).permute(0, 3, 1, 2).cuda().float()

    x = my_align(x, beziers)
    for i, roi in enumerate(x):
        roi = roi.cpu().detach().numpy().transpose(1, 2, 0).astype(np.uint8)
        im = Image.fromarray(roi, "RGB")
        im.save(f'aligned_{imgfile.split(".")[0]}.png')


if __name__ == '__main__':
    test()
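The index permutation applied to cps in the script above can be checked in isolation: applied to labeled values, it swaps the two values inside each point pair and reverses the order of the last four points (p5..p8 become p8..p5), matching the point-order comment in the script:

```python
# Symbolic check of the reordering applied to the flat 16-value cps array.
labels = [f"p{i}{c}" for i in range(1, 9) for c in "ab"]  # p1a, p1b, ..., p8b
perm = [1, 0, 3, 2, 5, 4, 7, 6, 15, 14, 13, 12, 11, 10, 9, 8]
reordered = [labels[i] for i in perm]
print(reordered)
# ['p1b', 'p1a', 'p2b', 'p2a', 'p3b', 'p3a', 'p4b', 'p4a',
#  'p8b', 'p8a', 'p7b', 'p7a', 'p6b', 'p6a', 'p5b', 'p5a']
```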